diff --git a/presto-array/src/main/java/io/prestosql/array/IntBigArray.java b/presto-array/src/main/java/io/prestosql/array/IntBigArray.java index 5d5c52f88..84fff1d6b 100644 --- a/presto-array/src/main/java/io/prestosql/array/IntBigArray.java +++ b/presto-array/src/main/java/io/prestosql/array/IntBigArray.java @@ -190,11 +190,11 @@ public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) this.segments = myState.segments; } - private static class IntBigArrayState + public static class IntBigArrayState implements Serializable { - private int[][] array; - private int capacity; - private int segments; + public int[][] array; + public int capacity; + public int segments; } } diff --git a/presto-array/src/main/java/io/prestosql/array/LongBigArray.java b/presto-array/src/main/java/io/prestosql/array/LongBigArray.java index 3dc64e38a..1853e2ca3 100644 --- a/presto-array/src/main/java/io/prestosql/array/LongBigArray.java +++ b/presto-array/src/main/java/io/prestosql/array/LongBigArray.java @@ -182,11 +182,11 @@ public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) this.segments = myState.segments; } - private static class LongBigArrayState + public static class LongBigArrayState implements Serializable { - private long[][] array; - private int capacity; - private int segments; + public long[][] array; + public int capacity; + public int segments; } } diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/HandTpchQuery1.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/HandTpchQuery1.java index 83c2d4535..cc6027140 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/HandTpchQuery1.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/HandTpchQuery1.java @@ -130,7 +130,8 @@ protected List createOperatorFactories() 10_000, Optional.of(new DataSize(16, MEGABYTE)), JOIN_COMPILER, - false); + false, + Optional.empty()); return ImmutableList.of(tableScanOperator, tpchQuery1Operator, aggregationOperator); } diff --git a/presto-benchmark/src/main/java/io/prestosql/benchmark/HashAggregationBenchmark.java b/presto-benchmark/src/main/java/io/prestosql/benchmark/HashAggregationBenchmark.java index 9392081ed..67d16d4bc 100644 --- a/presto-benchmark/src/main/java/io/prestosql/benchmark/HashAggregationBenchmark.java +++ b/presto-benchmark/src/main/java/io/prestosql/benchmark/HashAggregationBenchmark.java @@ -65,7 +65,8 @@ protected List createOperatorFactories() 100_000, Optional.of(new DataSize(16, MEGABYTE)), JOIN_COMPILER, - false); + false, + Optional.empty()); return ImmutableList.of(tableScanOperator, aggregationOperator); } diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/statistics/MetastoreHiveStatisticsProvider.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/statistics/MetastoreHiveStatisticsProvider.java index 9cfe135b6..001c571ae 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/statistics/MetastoreHiveStatisticsProvider.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/statistics/MetastoreHiveStatisticsProvider.java @@ -52,6 +52,7 @@ import java.math.BigDecimal; import java.time.LocalDate; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Comparator; import java.util.List; @@ -62,6 +63,7 @@ import java.util.OptionalLong; import java.util.Set; +import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static 
com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -159,10 +161,10 @@ else if (sample != null) { try { Map statisticsSample = statisticsProvider.getPartitionsStatistics(session, schemaTableName, partitionsSample, table); if (!includeColumnStatistics) { - OptionalDouble averageRows = calculateAverageRowsPerPartition(statisticsSample.values()); + Optional averageRows = calculatePartitionsRowCount(statisticsSample.values(), partitions.size()); TableStatistics.Builder result = TableStatistics.builder(); if (averageRows.isPresent()) { - result.setRowCount(Estimate.of(averageRows.getAsDouble() * partitions.size())); + result.setRowCount(Estimate.of(averageRows.get().getRowCount())); } result.setFileCount(calulateFileCount(statisticsSample.values())); result.setOnDiskDataSizeInBytes(calculateTotalOnDiskSizeInBytes(statisticsSample.values())); @@ -433,14 +435,12 @@ private static TableStatistics getTableStatistics( checkArgument(!partitions.isEmpty(), "partitions is empty"); - OptionalDouble optionalAverageRowsPerPartition = calculateAverageRowsPerPartition(statistics.values()); - if (!optionalAverageRowsPerPartition.isPresent()) { + Optional optionalRowCount = calculatePartitionsRowCount(statistics.values(), partitions.size()); + if (!optionalRowCount.isPresent()) { return TableStatistics.empty(); } - double averageRowsPerPartition = optionalAverageRowsPerPartition.getAsDouble(); - verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero"); - int queriedPartitionsCount = partitions.size(); - double rowCount = averageRowsPerPartition * queriedPartitionsCount; + + double rowCount = optionalRowCount.get().getRowCount(); TableStatistics.Builder result = TableStatistics.builder(); long fileCount = calulateFileCount(statistics.values()); @@ -457,6 +457,7 @@ private static TableStatistics getTableStatistics( if (columnHandle.isPartitionKey()) { tableColumnStatistics = statsCache.get(partitions.get(0).getTableName().getTableName() + columnName); if (tableColumnStatistics == null || invalidateStatsCache(tableColumnStatistics, Estimate.of(rowCount), fileCount, totalOnDiskSize)) { + double averageRowsPerPartition = optionalRowCount.get().getAverageRowsPerPartition(); columnStatistics = createPartitionColumnStatistics(columnHandle, columnType, partitions, statistics, averageRowsPerPartition, rowCount); TableStatistics tableStatistics = new TableStatistics(Estimate.of(rowCount), fileCount, totalOnDiskSize, ImmutableMap.of()); tableColumnStatistics = new TableColumnStatistics(tableStatistics, columnStatistics); @@ -485,15 +486,44 @@ private static boolean invalidateStatsCache(TableColumnStatistics tableColumnSta } @VisibleForTesting - static OptionalDouble calculateAverageRowsPerPartition(Collection statistics) + static Optional calculatePartitionsRowCount(Collection statistics, int queriedPartitionsCount) { - return statistics.stream() + long[] rowCounts = statistics.stream() .map(PartitionStatistics::getBasicStatistics) .map(HiveBasicStatistics::getRowCount) .filter(OptionalLong::isPresent) .mapToLong(OptionalLong::getAsLong) .peek(count -> verify(count >= 0, "count must be greater than or equal to zero")) - .average(); + .toArray(); + int sampleSize = statistics.size(); + // Sample contains all the queried partitions, estimate avg normally + if (rowCounts.length <= 2 || queriedPartitionsCount == sampleSize) { + OptionalDouble averageRowsPerPartitionOptional = Arrays.stream(rowCounts).average(); + if 
(!averageRowsPerPartitionOptional.isPresent()) { + return Optional.empty(); + } + double averageRowsPerPartition = averageRowsPerPartitionOptional.getAsDouble(); + return Optional.of(new PartitionsRowCount(averageRowsPerPartition, averageRowsPerPartition * queriedPartitionsCount)); + } + + // Some partitions (e.g. __HIVE_DEFAULT_PARTITION__) may be outliers in terms of row count. + // Excluding the min and max rowCount values from averageRowsPerPartition calculation helps to reduce the + // possibility of errors in the extrapolated rowCount due to a couple of outliers. + int minIndex = 0; + int maxIndex = 0; + long rowCountSum = rowCounts[0]; + for (int index = 1; index < rowCounts.length; index++) { + if (rowCounts[index] < rowCounts[minIndex]) { + minIndex = index; + } + else if (rowCounts[index] > rowCounts[maxIndex]) { + maxIndex = index; + } + rowCountSum += rowCounts[index]; + } + double averageWithoutOutliers = ((double) (rowCountSum - rowCounts[minIndex] - rowCounts[maxIndex])) / (rowCounts.length - 2); + double rowCount = (averageWithoutOutliers * (queriedPartitionsCount - 2)) + rowCounts[minIndex] + rowCounts[maxIndex]; + return Optional.of(new PartitionsRowCount(averageWithoutOutliers, rowCount)); } static long calulateFileCount(Collection statistics) @@ -932,4 +962,58 @@ interface PartitionsStatisticsProvider { Map getPartitionsStatistics(ConnectorSession session, SchemaTableName schemaTableName, List hivePartitions, Table table); } + + @VisibleForTesting + static class PartitionsRowCount + { + private final double averageRowsPerPartition; + private final double rowCount; + + PartitionsRowCount(double averageRowsPerPartition, double rowCount) + { + verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero"); + verify(rowCount >= 0, "rowCount must be greater than or equal to zero"); + this.averageRowsPerPartition = averageRowsPerPartition; + this.rowCount = rowCount; + } + + private double getAverageRowsPerPartition() + { + return averageRowsPerPartition; + } + + private double getRowCount() + { + return rowCount; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartitionsRowCount that = (PartitionsRowCount) o; + return Double.compare(that.averageRowsPerPartition, averageRowsPerPartition) == 0 + && Double.compare(that.rowCount, rowCount) == 0; + } + + @Override + public int hashCode() + { + return Objects.hash(averageRowsPerPartition, rowCount); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("averageRowsPerPartition", averageRowsPerPartition) + .add("rowCount", rowCount) + .toString(); + } + } } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java index 96e33c1f9..8f755edd4 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java @@ -151,8 +151,8 @@ import static io.prestosql.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static io.prestosql.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.prestosql.execution.SqlStageExecution.createSqlStageExecution; -import static io.prestosql.execution.scheduler.TestPhasedExecutionSchedule.createTableScanPlanFragment; import static 
io.prestosql.execution.scheduler.TestSourcePartitionedScheduler.createFixedSplitSource; +import static io.prestosql.execution.scheduler.policy.TestPhasedExecutionSchedule.createTableScanPlanFragment; import static io.prestosql.plugin.hive.HiveColumnHandle.BUCKET_COLUMN_NAME; import static io.prestosql.plugin.hive.HiveColumnHandle.PATH_COLUMN_NAME; import static io.prestosql.plugin.hive.HiveCompressionCodec.NONE; diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java index e102b2cfb..d0d062014 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/statistics/TestMetastoreHiveStatisticsProvider.java @@ -58,13 +58,14 @@ import static io.prestosql.plugin.hive.HiveType.HIVE_LONG; import static io.prestosql.plugin.hive.HiveType.HIVE_STRING; import static io.prestosql.plugin.hive.HiveUtil.parsePartitionValue; -import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateAverageRowsPerPartition; +import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.PartitionsRowCount; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSize; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSizeForPartitioningKey; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctPartitionKeys; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctValuesCount; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFraction; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFractionForPartitioningKey; +import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculatePartitionsRowCount; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRange; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRangeForPartitioningKey; import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.convertPartitionValueToDouble; @@ -82,6 +83,7 @@ import static io.prestosql.spi.type.VarcharType.VARCHAR; import static java.lang.Double.NaN; import static java.lang.String.format; +import static java.util.Collections.nCopies; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; @@ -240,15 +242,34 @@ public void testValidatePartitionStatistics() } @Test - public void testCalculateAverageRowsPerPartition() - { - assertThat(calculateAverageRowsPerPartition(ImmutableList.of())).isEmpty(); - assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty()))).isEmpty(); - assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()))).isEmpty(); - assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10))), OptionalDouble.of(10)); - assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), PartitionStatistics.empty())), OptionalDouble.of(10)); - 
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20))), OptionalDouble.of(15)); - assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty())), OptionalDouble.of(15)); + public void testCalculatePartitionsRowCount() + { + assertThat(calculatePartitionsRowCount(ImmutableList.of(), 0)).isEmpty(); + assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty()), 1)).isEmpty(); + assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()), 2)).isEmpty(); + assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 1)) + .isEqualTo(Optional.of(new MetastoreHiveStatisticsProvider.PartitionsRowCount(10, 10))); + assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 2)) + .isEqualTo(Optional.of(new PartitionsRowCount(10, 20))); + assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), PartitionStatistics.empty()), 2)) + .isEqualTo(Optional.of(new PartitionsRowCount(10, 20))); + assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 2)) + .isEqualTo(Optional.of(new PartitionsRowCount(15, 30))); + assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 3)) + .isEqualTo(Optional.of(new PartitionsRowCount(15, 45))); + assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty()), 3)) + .isEqualTo(Optional.of(new PartitionsRowCount(15, 45))); + + assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(100), rowsCount(1000)), 3)) + .isEqualTo(Optional.of(new PartitionsRowCount((10 + 100 + 1000) / 3.0, 10 + 100 + 1000))); + // Exclude outliers from average row count + assertThat(calculatePartitionsRowCount(ImmutableList.builder() + .addAll(nCopies(10, rowsCount(100))) + .add(rowsCount(1)) + .add(rowsCount(1000)) + .build(), + 50)) + .isEqualTo(Optional.of(new PartitionsRowCount(100, (100 * 48) + 1 + 1000))); } @Test diff --git a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java index ee36e4f0f..4a947e7b5 100644 --- a/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java +++ b/presto-main/src/main/java/io/prestosql/SystemSessionProperties.java @@ -99,6 +99,7 @@ public final class SystemSessionProperties public static final String REORDER_JOINS = "reorder_joins"; public static final String JOIN_REORDERING_STRATEGY = "join_reordering_strategy"; public static final String MAX_REORDERED_JOINS = "max_reordered_joins"; + public static final String JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR = "join_multi_clause_independence_factor"; public static final String SKIP_REORDERING_THRESHOLD = "skip_reordering_threshold"; public static final String INITIAL_SPLITS_PER_NODE = "initial_splits_per_node"; public static final String SPLIT_CONCURRENCY_ADJUSTMENT_INTERVAL = "split_concurrency_adjustment_interval"; @@ -134,6 +135,7 @@ public final class SystemSessionProperties public static final String IGNORE_STATS_CALCULATOR_FAILURES = "ignore_stats_calculator_failures"; public static final String MAX_DRIVERS_PER_TASK = "max_drivers_per_task"; public static final String DEFAULT_FILTER_FACTOR_ENABLED = "default_filter_factor_enabled"; + public static final String FILTER_CONJUNCTION_INDEPENDENCE_FACTOR = "filter_conjunction_independence_factor"; public static final String 
UNWRAP_CASTS = "unwrap_casts"; public static final String SKIP_REDUNDANT_SORT = "skip_redundant_sort"; public static final String PREDICATE_PUSHDOWN_USE_TABLE_PROPERTIES = "predicate_pushdown_use_table_properties"; @@ -201,6 +203,10 @@ public final class SystemSessionProperties public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY_GROWTH_FACTOR = "fault_tolerant_execution_task_memory_growth_factor"; public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY_ESTIMATION_QUANTILE = "fault_tolerant_execution_task_memory_estimation_quantile"; + public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled"; + public static final String ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS = "adaptive_partial_aggregation_min_rows"; + public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold"; + public static final String RETRY_POLICY = "retry_policy"; public static final String EXCHANGE_FILESYSTEM_BASE_DIRECTORY = "exchange_filesystem_base_directory"; @@ -473,6 +479,15 @@ public SystemSessionProperties( return intValue; }, value -> value), + new PropertyMetadata<>( + JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR, + "Scales the strength of independence assumption for selectivity estimates of multi-clause joins", + DOUBLE, + Double.class, + featuresConfig.getJoinMultiClauseIndependenceFactor(), + false, + value -> validateDoubleRange(value, JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR, 0.0, 1.0), + value -> value), new PropertyMetadata<>( SKIP_REORDERING_THRESHOLD, "Skip reordering joins if the number of joins in the logical plan is greater than this threshold", @@ -674,6 +689,15 @@ public SystemSessionProperties( "use a default filter factor for unknown filters in a filter node", featuresConfig.isDefaultFilterFactorEnabled(), false), + new PropertyMetadata<>( + FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, + "Scales the strength of independence assumption for selectivity estimates of the conjunction of multiple filters", + DOUBLE, + Double.class, + featuresConfig.getFilterConjunctionIndependenceFactor(), + false, + value -> validateDoubleRange(value, FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, 0.0, 1.0), + value -> value), booleanProperty( ENABLE_CROSS_REGION_DYNAMIC_FILTER, "Enable cross region dynamic filtering", @@ -998,6 +1022,21 @@ public SystemSessionProperties( QUERY_RESOURCE_TRACKING, "Query tracking feature enabled for current session", queryManagerConfig.isQueryResourceTracking(), + false), + booleanProperty( + ADAPTIVE_PARTIAL_AGGREGATION_ENABLED, + "When enabled, partial aggregation might be adaptively turned off when it does not provide any performance gain", + featuresConfig.isAdaptivePartialAggregationEnabled(), + false), + longProperty( + ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS, + "Minimum number of processed rows before partial aggregation might be adaptively turned off", + featuresConfig.getAdaptivePartialAggregationMinRows(), + false), + doubleProperty( + ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD, + "Ratio between aggregation output and input rows above which partial aggregation might be adaptively turned off", + featuresConfig.getAdaptivePartialAggregationUniqueRowsRatioThreshold(), false)); } @@ -1756,4 +1795,29 @@ public static boolean isDynamicScheduleForGroupedExecution(Session session) { return session.getSystemProperty(DYNAMIC_SCHEDULE_FOR_GROUPED_EXECUTION, Boolean.class); } + + public static double getJoinMultiClauseIndependenceFactor(Session session) + { 
+ return session.getSystemProperty(JOIN_MULTI_CLAUSE_INDEPENDENCE_FACTOR, Double.class); + } + + public static double getFilterConjunctionIndependenceFactor(Session session) + { + return session.getSystemProperty(FILTER_CONJUNCTION_INDEPENDENCE_FACTOR, Double.class); + } + + public static boolean isAdaptivePartialAggregationEnabled(Session session) + { + return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_ENABLED, Boolean.class); + } + + public static long getAdaptivePartialAggregationMinRows(Session session) + { + return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS, Long.class); + } + + public static double getAdaptivePartialAggregationUniqueRowsRatioThreshold(Session session) + { + return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD, Double.class); + } } diff --git a/presto-main/src/main/java/io/prestosql/cost/ComparisonStatsCalculator.java b/presto-main/src/main/java/io/prestosql/cost/ComparisonStatsCalculator.java index cea5e1f95..cc37fa18c 100644 --- a/presto-main/src/main/java/io/prestosql/cost/ComparisonStatsCalculator.java +++ b/presto-main/src/main/java/io/prestosql/cost/ComparisonStatsCalculator.java @@ -22,7 +22,7 @@ import static io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; import static io.prestosql.cost.SymbolStatsEstimate.buildFrom; -import static io.prestosql.util.MoreMath.firstNonNaN; +import static io.prestosql.util.MoreMath.averageExcludingNaNs; import static io.prestosql.util.MoreMath.max; import static io.prestosql.util.MoreMath.min; import static java.lang.Double.MAX_VALUE; @@ -263,15 +263,4 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression( rightExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, rightNullsFiltered)); return result.build(); } - - private static double averageExcludingNaNs(double first, double second) - { - if (isNaN(first) && isNaN(second)) { - return NaN; - } - if (!isNaN(first) && !isNaN(second)) { - return (first + second) / 2; - } - return firstNonNaN(first, second); - } } diff --git a/presto-main/src/main/java/io/prestosql/cost/FilterStatsCalculator.java b/presto-main/src/main/java/io/prestosql/cost/FilterStatsCalculator.java index d8db44759..99b139584 100644 --- a/presto-main/src/main/java/io/prestosql/cost/FilterStatsCalculator.java +++ b/presto-main/src/main/java/io/prestosql/cost/FilterStatsCalculator.java @@ -14,11 +14,14 @@ package io.prestosql.cost; import com.google.common.base.VerifyException; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ListMultimap; import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.expressions.LogicalRowExpressions; import io.prestosql.metadata.Metadata; +import io.prestosql.spi.PrestoException; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.function.FunctionMetadata; import io.prestosql.spi.function.OperatorType; @@ -58,6 +61,7 @@ import io.prestosql.sql.tree.NodeRef; import io.prestosql.sql.tree.NotExpression; import io.prestosql.sql.tree.SymbolReference; +import io.prestosql.util.DisjointSet; import javax.annotation.Nullable; import javax.inject.Inject; @@ -66,16 +70,23 @@ import java.util.Map; import java.util.Optional; import java.util.OptionalDouble; +import java.util.Set; +import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkArgument; +import static 
com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.SystemSessionProperties.getFilterConjunctionIndependenceFactor; import static io.prestosql.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison; import static io.prestosql.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison; import static io.prestosql.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues; import static io.prestosql.cost.PlanNodeStatsEstimateMath.capStats; +import static io.prestosql.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount; +import static io.prestosql.cost.PlanNodeStatsEstimateMath.intersectCorrelatedStats; import static io.prestosql.cost.PlanNodeStatsEstimateMath.subtractSubsetStats; import static io.prestosql.cost.StatsUtil.toStatsRepresentation; import static io.prestosql.expressions.LogicalRowExpressions.TRUE_CONSTANT; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.prestosql.spi.relation.SpecialForm.Form.IS_NULL; import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.sql.DynamicFilters.isDynamicFilter; @@ -83,6 +94,7 @@ import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.prestosql.sql.planner.RowExpressionInterpreter.Level.OPTIMIZED; import static io.prestosql.sql.planner.SymbolUtils.from; +import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique; import static io.prestosql.sql.relational.Expressions.call; import static io.prestosql.sql.relational.Expressions.constantNull; import static io.prestosql.sql.tree.ComparisonExpression.Operator.EQUAL; @@ -210,7 +222,14 @@ private class FilterExpressionStatsCalculatingVisitor @Override public PlanNodeStatsEstimate process(Node node, @Nullable Void context) { - return normalizer.normalize(super.process(node, context), types); + PlanNodeStatsEstimate output; + if (input.getOutputRowCount() == 0 || input.isOutputRowCountUnknown()) { + output = input; + } + else { + output = super.process(node, context); + } + return normalizer.normalize(output, types); } @Override @@ -233,7 +252,7 @@ protected PlanNodeStatsEstimate visitLogicalBinaryExpression(LogicalBinaryExpres { switch (node.getOperator()) { case AND: - return estimateLogicalAnd(node.getLeft(), node.getRight()); + return estimateLogicalAnd(ImmutableList.of(node.getLeft(), node.getRight())); case OR: return estimateLogicalOr(node.getLeft(), node.getRight()); default: @@ -241,6 +260,61 @@ protected PlanNodeStatsEstimate visitLogicalBinaryExpression(LogicalBinaryExpres } } + private PlanNodeStatsEstimate estimateLogicalAnd(List terms) + { + double filterConjunctionIndependenceFactor = getFilterConjunctionIndependenceFactor(session); + List estimates = estimateCorrelatedExpressions(terms, filterConjunctionIndependenceFactor); + double outputRowCount = estimateCorrelatedConjunctionRowCount( + input, + estimates, + filterConjunctionIndependenceFactor); + if (isNaN(outputRowCount)) { + return PlanNodeStatsEstimate.unknown(); + } + return normalizer.normalize(new PlanNodeStatsEstimate(outputRowCount, intersectCorrelatedStats(estimates)), types); + } + + /** + * There can be multiple predicate expressions for the same symbol, e.g. x > 0 AND x <= 1, x BETWEEN 1 AND 10. + * We attempt to detect such cases in extractCorrelatedGroups and calculate a combined estimate for each + * such group of expressions. 
This is done so that we don't apply the above scaling factors when combining estimates + * from conjunction of multiple predicates on the same symbol and underestimate the output. + **/ + private List estimateCorrelatedExpressions(List terms, double filterConjunctionIndependenceFactor) + { + List> extractedCorrelatedGroups = extractCorrelatedGroups(terms, filterConjunctionIndependenceFactor); + ImmutableList.Builder estimatesBuilder = ImmutableList.builderWithExpectedSize(extractedCorrelatedGroups.size()); + boolean hasUnestimatedTerm = false; + for (List correlatedExpressions : extractedCorrelatedGroups) { + PlanNodeStatsEstimate combinedEstimate = PlanNodeStatsEstimate.unknown(); + for (Expression expression : correlatedExpressions) { + PlanNodeStatsEstimate estimate; + // combinedEstimate is unknown until the 1st known estimated term + if (combinedEstimate.isOutputRowCountUnknown()) { + estimate = process(expression); + } + else { + estimate = new FilterExpressionStatsCalculatingVisitor(combinedEstimate, session, types) + .process(expression); + } + + if (estimate.isOutputRowCountUnknown()) { + hasUnestimatedTerm = true; + } + else { + // update combinedEstimate only when the term estimate is known so that all the known estimates + // can be applied progressively through FilterExpressionStatsCalculatingVisitor calls. + combinedEstimate = estimate; + } + } + estimatesBuilder.add(combinedEstimate); + } + if (hasUnestimatedTerm) { + estimatesBuilder.add(PlanNodeStatsEstimate.unknown()); + } + return estimatesBuilder.build(); + } + private PlanNodeStatsEstimate estimateLogicalAnd(Expression left, Expression right) { // first try to estimate in the fair way @@ -935,4 +1009,53 @@ private SymbolStatsEstimate getRowExpressionStats(RowExpression expression) return scalarStatsCalculator.calculate(expression, input, session, layout); } } + + private static List> extractCorrelatedGroups(List terms, double filterConjunctionIndependenceFactor) + { + if (filterConjunctionIndependenceFactor == 1) { + // Allows the filters to be estimated as if there is no correlation between any of the terms + return ImmutableList.of(terms); + } + + ListMultimap expressionUniqueSymbols = ArrayListMultimap.create(); + terms.forEach(expression -> expressionUniqueSymbols.putAll(expression, extractUnique(expression))); + // Partition symbols into disjoint sets such that the symbols belonging to different disjoint sets + // do not appear together in any expression. 
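+        // For example, given the terms a > 0, a + b < 10 and c = 5, the symbol partitions are {a, b} and {c},
+        // so a > 0 and a + b < 10 fall into one correlated group while c = 5 is grouped separately.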
+ DisjointSet symbolsPartitioner = new DisjointSet<>(); + for (Expression term : terms) { + List expressionSymbols = expressionUniqueSymbols.get(term); + if (expressionSymbols.isEmpty()) { + continue; + } + // Ensure that symbol is added to DisjointSet when there is only one symbol in the list + symbolsPartitioner.find(expressionSymbols.get(0)); + for (int i = 1; i < expressionSymbols.size(); i++) { + symbolsPartitioner.findAndUnion(expressionSymbols.get(0), expressionSymbols.get(i)); + } + } + + // Use disjoint sets of symbols to partition the given list of expressions + List> symbolPartitions = ImmutableList.copyOf(symbolsPartitioner.getEquivalentClasses()); + checkState(symbolPartitions.size() <= terms.size(), "symbolPartitions size exceeds number of expressions"); + ListMultimap expressionPartitions = ArrayListMultimap.create(); + for (Expression term : terms) { + List expressionSymbols = expressionUniqueSymbols.get(term); + int expressionPartitionId; + if (expressionSymbols.isEmpty()) { + expressionPartitionId = symbolPartitions.size(); // For expressions with no symbols + } + else { + Symbol symbol = expressionSymbols.get(0); // Lookup any symbol to find the partition id + expressionPartitionId = IntStream.range(0, symbolPartitions.size()) + .filter(partition -> symbolPartitions.get(partition).contains(symbol)) + .findFirst() + .orElseThrow(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "Requested symbol not found")); + } + expressionPartitions.put(expressionPartitionId, term); + } + + return expressionPartitions.keySet().stream() + .map(expressionPartitions::get) + .collect(toImmutableList()); + } } diff --git a/presto-main/src/main/java/io/prestosql/cost/JoinStatsRule.java b/presto-main/src/main/java/io/prestosql/cost/JoinStatsRule.java index 54b776266..33eb71fdd 100644 --- a/presto-main/src/main/java/io/prestosql/cost/JoinStatsRule.java +++ b/presto-main/src/main/java/io/prestosql/cost/JoinStatsRule.java @@ -28,16 +28,16 @@ import java.util.Collection; import java.util.HashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Queue; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Sets.difference; +import static io.prestosql.SystemSessionProperties.getJoinMultiClauseIndependenceFactor; import static io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; +import static io.prestosql.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount; import static io.prestosql.cost.SymbolStatsEstimate.buildFrom; import static io.prestosql.sql.ExpressionUtils.extractConjuncts; import static io.prestosql.sql.planner.SymbolUtils.toSymbolReference; @@ -204,23 +204,23 @@ private PlanNodeStatsEstimate filterByEquiJoinClauses( TypeProvider types) { checkArgument(!clauses.isEmpty(), "clauses is empty"); - PlanNodeStatsEstimate result = PlanNodeStatsEstimate.unknown(); - // Join equality clauses are usually correlated. Therefore we shouldn't treat each join equality - // clause separately because stats estimates would be way off. Instead we choose so called - // "driving clause" which mostly reduces join output rows cardinality and apply UNKNOWN_FILTER_COEFFICIENT - // for other (auxiliary) clauses. 
- Queue remainingClauses = new LinkedList<>(clauses); - EquiJoinClause drivingClause = remainingClauses.poll(); - for (int i = 0; i < clauses.size(); i++) { - PlanNodeStatsEstimate estimate = filterByEquiJoinClauses(stats, drivingClause, remainingClauses, session, types); - if (result.isOutputRowCountUnknown() || (!estimate.isOutputRowCountUnknown() && estimate.getOutputRowCount() < result.getOutputRowCount())) { - result = estimate; - } - remainingClauses.add(drivingClause); - drivingClause = remainingClauses.poll(); - } + // Join equality clauses are usually correlated. Therefore, we shouldn't treat each join equality + // clause separately because stats estimates would be way off. + List knownEstimates = clauses.stream() + .map(clause -> { + ComparisonExpression predicate = new ComparisonExpression(EQUAL, toSymbolReference(clause.getLeft()), toSymbolReference(clause.getRight())); + return new PlanNodeStatsEstimateWithClause(filterStatsCalculator.filterStats(stats, predicate, session, types), clause); + }) + .collect(toImmutableList()); - return result; + double outputRowCount = estimateCorrelatedConjunctionRowCount( + stats, + knownEstimates.stream().map(PlanNodeStatsEstimateWithClause::getEstimate).collect(toImmutableList()), + getJoinMultiClauseIndependenceFactor(session)); + if (isNaN(outputRowCount)) { + return PlanNodeStatsEstimate.unknown(); + } + return normalizer.normalize(new PlanNodeStatsEstimate(outputRowCount, intersectCorrelatedJoinClause(stats, knownEstimates)), types); } private PlanNodeStatsEstimate filterByEquiJoinClauses( @@ -445,4 +445,67 @@ private List flippedCriteria(JoinNode node) .map(EquiJoinClause::flip) .collect(toImmutableList()); } + + private static Map intersectCorrelatedJoinClause( + PlanNodeStatsEstimate stats, + List equiJoinClauseEstimates) + { + // Add initial statistics (including stats for columns which are not part of equi-join clauses) + PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder() + .addSymbolStatistics(stats.getSymbolStatistics()); + + for (PlanNodeStatsEstimateWithClause estimateWithClause : equiJoinClauseEstimates) { + EquiJoinClause clause = estimateWithClause.getClause(); + // we just clear null fraction and adjust ranges here, selectivity is handled outside this function + SymbolStatsEstimate leftStats = stats.getSymbolStatistics(clause.getLeft()); + SymbolStatsEstimate rightStats = stats.getSymbolStatistics(clause.getRight()); + StatisticRange leftRange = StatisticRange.from(leftStats); + StatisticRange rightRange = StatisticRange.from(rightStats); + + StatisticRange intersect = leftRange.intersect(rightRange); + double leftFilterValue = firstNonNaN(leftRange.overlapPercentWith(intersect), 1); + double rightFilterValue = firstNonNaN(rightRange.overlapPercentWith(intersect), 1); + double leftNdvInRange = leftFilterValue * leftRange.getDistinctValuesCount(); + double rightNdvInRange = rightFilterValue * rightRange.getDistinctValuesCount(); + double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange); + + SymbolStatsEstimate newLeftStats = buildFrom(leftStats) + .setNullsFraction(0) + .setStatisticsRange(intersect) + .setDistinctValuesCount(retainedNdv) + .build(); + + SymbolStatsEstimate newRightStats = buildFrom(rightStats) + .setNullsFraction(0) + .setStatisticsRange(intersect) + .setDistinctValuesCount(retainedNdv) + .build(); + + result.addSymbolStatistics(clause.getLeft(), newLeftStats) + .addSymbolStatistics(clause.getRight(), newRightStats); + } + return result.build().getSymbolStatistics(); + } + + 
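+    // Pairs an equi-join clause with the filter-stats estimate computed for that clause alone, so that the
+    // per-clause estimates can later be combined under the join multi-clause independence factor.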
private static class PlanNodeStatsEstimateWithClause + { + private final PlanNodeStatsEstimate estimate; + private final EquiJoinClause clause; + + private PlanNodeStatsEstimateWithClause(PlanNodeStatsEstimate estimate, EquiJoinClause clause) + { + this.estimate = requireNonNull(estimate, "estimate is null"); + this.clause = requireNonNull(clause, "clause is null"); + } + + private PlanNodeStatsEstimate getEstimate() + { + return estimate; + } + + private EquiJoinClause getClause() + { + return clause; + } + } } diff --git a/presto-main/src/main/java/io/prestosql/cost/PlanNodeStatsEstimateMath.java b/presto-main/src/main/java/io/prestosql/cost/PlanNodeStatsEstimateMath.java index 2d31457e8..00a9cf5b9 100644 --- a/presto-main/src/main/java/io/prestosql/cost/PlanNodeStatsEstimateMath.java +++ b/presto-main/src/main/java/io/prestosql/cost/PlanNodeStatsEstimateMath.java @@ -13,11 +13,22 @@ */ package io.prestosql.cost; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.plan.Symbol; +import io.prestosql.util.MoreMath; + +import java.util.List; +import java.util.Map; + import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static java.lang.Double.NaN; import static java.lang.Double.isNaN; import static java.lang.Double.max; import static java.lang.Double.min; +import static java.util.Comparator.comparingDouble; import static java.util.stream.Stream.concat; public class PlanNodeStatsEstimateMath @@ -217,4 +228,81 @@ private static SymbolStatsEstimate addColumnStats(SymbolStatsEstimate leftStats, .setNullsFraction(newNullsFraction) .build(); } + + public static double estimateCorrelatedConjunctionRowCount( + PlanNodeStatsEstimate input, + List estimates, + double independenceFactor) + { + checkArgument(!estimates.isEmpty(), "estimates is empty"); + if (input.isOutputRowCountUnknown() || input.getOutputRowCount() == 0) { + return input.getOutputRowCount(); + } + List knownSortedEstimates = estimates.stream() + .filter(estimateInfo -> !estimateInfo.isOutputRowCountUnknown()) + .sorted(comparingDouble(PlanNodeStatsEstimate::getOutputRowCount)) + .collect(toImmutableList()); + if (knownSortedEstimates.isEmpty()) { + return NaN; + } + + PlanNodeStatsEstimate combinedEstimate = knownSortedEstimates.get(0); + double combinedSelectivity = combinedEstimate.getOutputRowCount() / input.getOutputRowCount(); + double combinedIndependenceFactor = 1.0; + // For independenceFactor = 0.75 and terms t1, t2, t3 + // Combined selectivity = (t1 selectivity) * ((t2 selectivity) ^ 0.75) * ((t3 selectivity) ^ (0.75 * 0.75)) + // independenceFactor = 1 implies the terms are assumed to have no correlation and their selectivities are multiplied without scaling. + // independenceFactor = 0 implies the terms are assumed to be fully correlated and only the most selective term drives the selectivity. 
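+        // For example, with independenceFactor = 0.5 and term selectivities 0.1, 0.2 and 0.5 (sorted by ascending row count),
+        // combined selectivity = 0.1 * (0.2 ^ 0.5) * (0.5 ^ 0.25), which is about 0.038 - between the fully
+        // correlated estimate (0.1) and the fully independent product (0.01).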
+ for (int i = 1; i < knownSortedEstimates.size(); i++) { + PlanNodeStatsEstimate term = knownSortedEstimates.get(i); + combinedIndependenceFactor *= independenceFactor; + combinedSelectivity *= Math.pow(term.getOutputRowCount() / input.getOutputRowCount(), combinedIndependenceFactor); + } + double outputRowCount = input.getOutputRowCount() * combinedSelectivity; + // TODO use UNKNOWN_FILTER_COEFFICIENT only when default-filter-factor is enabled + boolean hasUnestimatedTerm = estimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown); + return hasUnestimatedTerm ? outputRowCount * UNKNOWN_FILTER_COEFFICIENT : outputRowCount; + } + + public static Map intersectCorrelatedStats(List estimates) + { + checkArgument(!estimates.isEmpty(), "estimates is empty"); + if (estimates.size() == 1) { + return estimates.get(0).getSymbolStatistics(); + } + PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder(); + // Update statistic range for symbols + estimates.stream().flatMap(estimate -> estimate.getSymbolsWithKnownStatistics().stream()) + .distinct() + .forEach(symbol -> { + List symbolStatsEstimates = estimates.stream() + .map(estimate -> estimate.getSymbolStatistics(symbol)) + .collect(toImmutableList()); + + StatisticRange intersect = symbolStatsEstimates.stream() + .map(StatisticRange::from) + .reduce(StatisticRange::intersect) + .orElseThrow(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "StatisticRange is not present")); + + // intersectCorrelatedStats should try to produce stats as if filters are applied in sequence. + // Using min works for filters like (a > 10 AND b < 10), but won't work for + // (a > 10 AND b IS NULL). However, former case is more common. + double nullsFraction = symbolStatsEstimates.stream() + .map(SymbolStatsEstimate::getNullsFraction) + .reduce(MoreMath::minExcludeNaN) + .orElseThrow(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "Nulls Fraction is not present")); + + double averageRowSize = symbolStatsEstimates.stream() + .map(SymbolStatsEstimate::getAverageRowSize) + .reduce(MoreMath::averageExcludingNaNs) + .orElseThrow(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "Average Row Size is not present")); + + result.addSymbolStatistics(symbol, SymbolStatsEstimate.builder() + .setStatisticsRange(intersect) + .setNullsFraction(nullsFraction) + .setAverageRowSize(averageRowSize) + .build()); + }); + return result.build().getSymbolStatistics(); + } } diff --git a/presto-main/src/main/java/io/prestosql/cost/StatisticRange.java b/presto-main/src/main/java/io/prestosql/cost/StatisticRange.java index fa504a682..1175eea54 100644 --- a/presto-main/src/main/java/io/prestosql/cost/StatisticRange.java +++ b/presto-main/src/main/java/io/prestosql/cost/StatisticRange.java @@ -17,6 +17,8 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static io.prestosql.util.MoreMath.maxExcludeNaN; +import static io.prestosql.util.MoreMath.minExcludeNaN; import static java.lang.Double.NaN; import static java.lang.Double.isFinite; import static java.lang.Double.isInfinite; @@ -173,28 +175,6 @@ public StatisticRange addAndCollapseDistinctValues(StatisticRange other) return new StatisticRange(minExcludeNaN(low, other.low), maxExcludeNaN(high, other.high), newDistinctValues); } - private static double minExcludeNaN(double v1, double v2) - { - if (isNaN(v1)) { - return v2; - } - if (isNaN(v2)) { - return v1; - } - return min(v1, v2); - } - - private static double 
maxExcludeNaN(double v1, double v2) - { - if (isNaN(v1)) { - return v2; - } - if (isNaN(v2)) { - return v1; - } - return max(v1, v2); - } - @Override public boolean equals(Object o) { diff --git a/presto-main/src/main/java/io/prestosql/dynamicfilter/DynamicFilterService.java b/presto-main/src/main/java/io/prestosql/dynamicfilter/DynamicFilterService.java index 6feede54c..77557aa63 100644 --- a/presto-main/src/main/java/io/prestosql/dynamicfilter/DynamicFilterService.java +++ b/presto-main/src/main/java/io/prestosql/dynamicfilter/DynamicFilterService.java @@ -14,14 +14,17 @@ */ package io.prestosql.dynamicfilter; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; import io.airlift.log.Logger; import io.prestosql.Session; +import io.prestosql.execution.SqlQueryExecution; import io.prestosql.execution.StageStateMachine; import io.prestosql.execution.TaskId; import io.prestosql.metadata.InternalNode; +import io.prestosql.operator.JoinUtils; import io.prestosql.spi.PrestoException; import io.prestosql.spi.QueryId; import io.prestosql.spi.connector.ColumnHandle; @@ -42,6 +45,9 @@ import io.prestosql.spi.statestore.StateStore; import io.prestosql.spi.util.BloomFilter; import io.prestosql.sql.DynamicFilters; +import io.prestosql.sql.planner.PlanFragment; +import io.prestosql.sql.planner.SubPlan; +import io.prestosql.sql.planner.optimizations.PlanNodeSearcher; import io.prestosql.sql.planner.plan.SemiJoinNode; import io.prestosql.statestore.StateStoreProvider; import io.prestosql.utils.DynamicFilterUtils; @@ -74,6 +80,10 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.difference; +import static com.google.common.collect.Sets.intersection; +import static com.google.common.collect.Sets.union; import static io.airlift.concurrent.Threads.threadsNamed; import static io.prestosql.SystemSessionProperties.getDynamicFilteringDataType; import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; @@ -84,6 +94,9 @@ import static io.prestosql.spi.dynamicfilter.DynamicFilter.Type.LOCAL; import static io.prestosql.spi.statestore.StateCollection.Type.MAP; import static io.prestosql.spi.statestore.StateCollection.Type.SET; +import static io.prestosql.sql.DynamicFilters.extractDynamicFilters; +import static io.prestosql.sql.planner.ExpressionExtractor.extractExpressions; +import static io.prestosql.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static io.prestosql.utils.DynamicFilterUtils.createKey; import static io.prestosql.utils.DynamicFilterUtils.findFilterNodeInStage; import static io.prestosql.utils.DynamicFilterUtils.getDynamicFilterDataType; @@ -101,6 +114,7 @@ public class DynamicFilterService private final Map> dynamicFiltersToTask = new ConcurrentHashMap<>(); private static final Map> cachedDynamicFilters = new HashMap<>(); private final List finishedQuery = Collections.synchronizedList(new ArrayList<>()); + private List registeredQueries = new ArrayList<>(); private final StateStoreProvider stateStoreProvider; @@ -247,6 +261,7 @@ private void removeFinishedQuery() dynamicFilters.remove(queryId); cachedDynamicFilters.remove(queryId); + registeredQueries.remove(queryId); handledQuery.add(queryId); } finishedQuery.removeAll(handledQuery); @@ -507,6 
+522,110 @@ private static DynamicFilterRegistryInfo extractDynamicFilterRegistryInfo(SemiJo } } + public void registerQuery(SqlQueryExecution sqlQueryExecution, SubPlan fragmentedPlan) + { + PlanNode queryPlan = sqlQueryExecution.getQueryPlan().getRoot(); + Set dynamicFilters = getProducedDynamicFilters(queryPlan); + + if (!dynamicFilters.isEmpty()) { + registeredQueries.add(sqlQueryExecution.getQueryId()); + } + } + + /** + * Dynamic filters are collected in same stage as the join operator in pipelined execution. This can result in deadlock + * for source stage joins and connectors that wait for dynamic filters before generating splits + * (probe splits might be blocked on dynamic filters which require at least one probe task in order to be collected). + * To overcome this issue an initial task is created for source stages running broadcast join operator. + * This task allows for dynamic filters collection without any probe side splits being scheduled. + */ + public boolean isCollectingTaskNeeded(QueryId queryId, PlanFragment plan) + { + if (!registeredQueries.contains(queryId)) { + // query has been removed or not registered (e.g. dynamic filtering is disabled) + return false; + } + + // dynamic filters are collected by additional task only for non-fixed source stage + return plan.getPartitioning().equals(SOURCE_DISTRIBUTION) && !getLazyDynamicFilters(plan).isEmpty(); + } + + public boolean isStageSchedulingNeededToCollectDynamicFilters(QueryId queryId, PlanFragment plan) + { + if (!registeredQueries.contains(queryId)) { + // query has been removed or not registered (e.g. dynamic filtering is disabled) + return false; + } + + // stage scheduling is not needed to collect dynamic filters for non-fixed source stage, because + // for such stage collecting task is created + return !plan.getPartitioning().equals(SOURCE_DISTRIBUTION) && !getLazyDynamicFilters(plan).isEmpty(); + } + + private static Set getLazyDynamicFilters(PlanFragment plan) + { + // To prevent deadlock dynamic filter can be lazy only when: + // 1. it's consumed by different stage from where it's produced + // 2. or it's produced by replicated join in source stage. In such case an extra + // task is created that will collect dynamic filter and prevent deadlock. + Set interStageDynamicFilters = difference(getProducedDynamicFilters(plan.getRoot()), getConsumedDynamicFilters(plan.getRoot())); + return ImmutableSet.copyOf(union(interStageDynamicFilters, getSourceStageInnerLazyDynamicFilters(plan))); + } + + @VisibleForTesting + static Set getSourceStageInnerLazyDynamicFilters(PlanFragment plan) + { + if (!plan.getPartitioning().equals(SOURCE_DISTRIBUTION)) { + // Only non-fixed source stages can have (replicated) lazy dynamic filters that are + // produced and consumed within stage. This is because for such stages an extra + // dynamic filtering collecting task can be added. 
+ return ImmutableSet.of(); + } + + PlanNode planNode = plan.getRoot(); + Set innerStageDynamicFilters = intersection(getProducedDynamicFilters(planNode), getConsumedDynamicFilters(planNode)); + Set replicatedDynamicFilters = getReplicatedDynamicFilters(planNode); + return ImmutableSet.copyOf(intersection(innerStageDynamicFilters, replicatedDynamicFilters)); + } + + private static Set getReplicatedDynamicFilters(PlanNode planNode) + { + return PlanNodeSearcher.searchFrom(planNode) + .whereIsInstanceOfAny(JoinNode.class, SemiJoinNode.class) + .findAll().stream() + .filter(JoinUtils::isBuildSideReplicated) + .flatMap(node -> getDynamicFiltersProducedInPlanNode(node).stream()) + .collect(toImmutableSet()); + } + + private static Set getProducedDynamicFilters(PlanNode planNode) + { + return PlanNodeSearcher.searchFrom(planNode) + .whereIsInstanceOfAny(JoinNode.class, SemiJoinNode.class) + .findAll().stream() + .flatMap(node -> getDynamicFiltersProducedInPlanNode(node).stream()) + .collect(toImmutableSet()); + } + + private static Set getConsumedDynamicFilters(PlanNode planNode) + { + return extractExpressions(planNode).stream() + .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream()) + .map(DynamicFilters.Descriptor::getId) + .collect(toImmutableSet()); + } + + private static Set getDynamicFiltersProducedInPlanNode(PlanNode planNode) + { + if (planNode instanceof JoinNode) { + return ((JoinNode) planNode).getDynamicFilters().keySet(); + } + if (planNode instanceof SemiJoinNode) { + return ((SemiJoinNode) planNode).getDynamicFilterId().map(ImmutableSet::of).orElse(ImmutableSet.of()); + } + throw new IllegalStateException("getDynamicFiltersProducedInPlanNode called with neither JoinNode nor SemiJoinNode"); + } + private static class DynamicFilterRegistryInfo { private final Symbol symbol; diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java index b9da19b5e..f9fe11798 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java @@ -34,7 +34,6 @@ import io.prestosql.execution.StateMachine.StateChangeListener; import io.prestosql.execution.buffer.OutputBuffers; import io.prestosql.execution.buffer.OutputBuffers.OutputBufferId; -import io.prestosql.execution.scheduler.ExecutionPolicy; import io.prestosql.execution.scheduler.NodeAllocatorService; import io.prestosql.execution.scheduler.NodeScheduler; import io.prestosql.execution.scheduler.PartitionMemoryEstimatorFactory; @@ -43,6 +42,7 @@ import io.prestosql.execution.scheduler.TaskDescriptorStorage; import io.prestosql.execution.scheduler.TaskExecutionStats; import io.prestosql.execution.scheduler.TaskSourceFactory; +import io.prestosql.execution.scheduler.policy.ExecutionPolicy; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.failuredetector.FailureDetector; import io.prestosql.heuristicindex.HeuristicIndexerManager; @@ -606,6 +606,7 @@ public void start() plan = analyzeQuery(); try { + registerDynamicFilteringQuery(plan); handleCrossRegionDynamicFilter(plan); } catch (Throwable e) { @@ -944,6 +945,20 @@ private static Set extractConnectors(Analysis analysis) return connectors.build(); } + private synchronized void registerDynamicFilteringQuery(PlanRoot plan) + { + if (!isEnableDynamicFiltering(stateMachine.getSession())) { + return; + } + + if (isDone()) { + // query has finished 
or was cancelled asynchronously + return; + } + + dynamicFilterService.registerQuery(this, plan.getRoot()); + } + private void planDistribution(PlanRoot plan) { // time distribution planning diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java index 7c8561d3b..ad847703e 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java @@ -31,6 +31,7 @@ import io.prestosql.exchange.ExchangeSinkInstanceHandle; import io.prestosql.execution.StateMachine.StateChangeListener; import io.prestosql.execution.buffer.OutputBuffers; +import io.prestosql.execution.scheduler.PartitionIdAllocator; import io.prestosql.execution.scheduler.SplitSchedulerStats; import io.prestosql.failuredetector.FailureDetector; import io.prestosql.metadata.InternalNode; @@ -355,6 +356,18 @@ public synchronized void suspend() getAllTasks().forEach(RemoteTask::suspend); } + public List getTaskStatuses() + { + return getAllTasks().stream() + .map(RemoteTask::getTaskStatus) + .collect(toImmutableList()); + } + + public boolean isAnyTaskBlocked() + { + return getTaskStatuses().stream().anyMatch(TaskStatus::isOutputBufferOverutilized); + } + public synchronized void resume() { stateMachine.transitionToRunning(); @@ -571,7 +584,18 @@ public synchronized Optional scheduleTask(InternalNode node, int par return Optional.of(scheduleTask(node, new TaskId(stateMachine.getStageId(), partition, 0), generateInstanceId(), ImmutableMultimap.of(), totalPartitions)); } - public synchronized Set scheduleSplits(InternalNode node, Multimap splits, Multimap noMoreSplitsNotification) + public synchronized Optional scheduleTask(InternalNode node, int partition, Multimap initialSplits) + { + requireNonNull(node, "node is null"); + + if (stateMachine.getState().isDone()) { + return Optional.empty(); + } + checkState(!splitsScheduled.get(), "scheduleTask can not be called once splits have been scheduled"); + return Optional.of(scheduleTask(node, new TaskId(stateMachine.getStageId(), partition, 0), generateInstanceId(), initialSplits, OptionalInt.empty())); + } + + public synchronized Set scheduleSplits(InternalNode node, Multimap splits, Multimap noMoreSplitsNotification, PartitionIdAllocator partitionIdAllocator) { requireNonNull(node, "node is null"); requireNonNull(splits, "splits is null"); @@ -589,7 +613,7 @@ public synchronized Set scheduleSplits(InternalNode node, Multimap

undeclaredCreatedBuffers = Sets.difference(buffers.keySet(), outputBuffers.getBuffers().keySet()); - checkState(undeclaredCreatedBuffers.isEmpty(), "Final output buffers does not contain all created buffer ids: %s", undeclaredCreatedBuffers); + checkState(undeclaredCreatedBuffers.isEmpty(), "Final output buffers does not contain all created buffer ids: %s [buffers: %s] [outputBuffers: %s] ", undeclaredCreatedBuffers, + buffers.keySet(), outputBuffers.getBuffers().keySet()); } } diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/FixedSourcePartitionedScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/FixedSourcePartitionedScheduler.java index 5b581c543..41de1d8fc 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -15,10 +15,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Streams; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.prestosql.Session; +import io.prestosql.dynamicfilter.DynamicFilterService; import io.prestosql.execution.Lifespan; import io.prestosql.execution.RemoteTask; import io.prestosql.execution.SqlStageExecution; @@ -36,6 +36,7 @@ import io.prestosql.split.SplitSource; import java.util.ArrayList; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -47,11 +48,9 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; import static io.prestosql.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsSourceScheduler; import static io.prestosql.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; -import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class FixedSourcePartitionedScheduler @@ -63,8 +62,10 @@ public class FixedSourcePartitionedScheduler private final List nodes; private final List sourceSchedulers; private final List partitionHandles; - private boolean scheduledTasks; + private final Map scheduledTasks; private final Optional groupedLifespanScheduler; + private final PartitionIdAllocator partitionIdAllocator; + private final DynamicFilterService dynamicFilterService; public FixedSourcePartitionedScheduler( SqlStageExecution stage, @@ -79,7 +80,8 @@ public FixedSourcePartitionedScheduler( List partitionHandles, Session session, HeuristicIndexerManager heuristicIndexerManager, - TableExecuteContextManager tableExecuteContextManager) + TableExecuteContextManager tableExecuteContextManager, + DynamicFilterService dynamicFilterService) { requireNonNull(stage, "stage is null"); requireNonNull(splitSources, "splitSources is null"); @@ -110,6 +112,8 @@ public FixedSourcePartitionedScheduler( } boolean firstPlanNode = true; + partitionIdAllocator = new PartitionIdAllocator(); + scheduledTasks = new HashMap<>(); Optional groupedLifespanSchedulerOptional = Optional.empty(); for (PlanNodeId planNodeId : schedulingOrder) { SplitSource splitSource = splitSources.get(planNodeId); @@ -124,7 +128,10 @@ public FixedSourcePartitionedScheduler( 
groupedExecutionForScanNode, session, heuristicIndexerManager, - tableExecuteContextManager); + tableExecuteContextManager, + partitionIdAllocator, + scheduledTasks, + dynamicFilterService); if (stageExecutionDescriptor.isStageGroupedExecution() && !groupedExecutionForScanNode) { sourceScheduler = new AsGroupedSourceScheduler(sourceScheduler); @@ -161,6 +168,7 @@ public FixedSourcePartitionedScheduler( } this.groupedLifespanScheduler = groupedLifespanSchedulerOptional; this.sourceSchedulers = sourceSchedulerArrayList; + this.dynamicFilterService = dynamicFilterService; } private ConnectorPartitionHandle partitionHandleFor(Lifespan lifespan) @@ -182,15 +190,17 @@ public ScheduleResult schedule(int maxSplitGroup) { // schedule a task on every node in the distribution List newTasks = ImmutableList.of(); - if (!scheduledTasks) { + if (scheduledTasks.isEmpty()) { OptionalInt totalPartitions = OptionalInt.of(nodes.size()); - newTasks = Streams.mapWithIndex( - nodes.stream(), - (node, id) -> stage.scheduleTask(node, toIntExact(id), totalPartitions)) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(toImmutableList()); - scheduledTasks = true; + ImmutableList.Builder newTasksBuilder = ImmutableList.builder(); + for (InternalNode node : nodes) { + Optional task = stage.scheduleTask(node, partitionIdAllocator.getNextId(), totalPartitions); + if (task.isPresent()) { + scheduledTasks.put(node, task.get()); + newTasksBuilder.add(task.get()); + } + } + newTasks = newTasksBuilder.build(); } boolean allBlocked = true; @@ -323,6 +333,12 @@ public AsGroupedSourceScheduler(SourceScheduler sourceScheduler) pendingCompleted = new ArrayList<>(); } + @Override + public Optional start() + { + return sourceScheduler.start(); + } + @Override public ScheduleResult schedule() { diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/PartitionIdAllocator.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/PartitionIdAllocator.java new file mode 100644 index 000000000..54edab399 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/PartitionIdAllocator.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.execution.scheduler; + +import java.util.concurrent.atomic.AtomicInteger; + +public class PartitionIdAllocator +{ + private final AtomicInteger nextId = new AtomicInteger(); + + public int getNextId() + { + return nextId.getAndIncrement(); + } +} diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/SourcePartitionedScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/SourcePartitionedScheduler.java index 2df76168a..1e7e041cf 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/SourcePartitionedScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/SourcePartitionedScheduler.java @@ -24,6 +24,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.prestosql.Session; import io.prestosql.SystemSessionProperties; +import io.prestosql.dynamicfilter.DynamicFilterService; import io.prestosql.execution.Lifespan; import io.prestosql.execution.RemoteTask; import io.prestosql.execution.SqlStageExecution; @@ -54,6 +55,7 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -117,6 +119,9 @@ private enum State private SettableFuture whenFinishedOrNewLifespanAdded = SettableFuture.create(); private int throttledSplitsCount; + private PartitionIdAllocator partitionIdAllocator; + private final Map scheduledTasks; + private final DynamicFilterService dynamicFilterService; private SourcePartitionedScheduler( SqlStageExecution stage, @@ -127,7 +132,10 @@ private SourcePartitionedScheduler( boolean groupedExecution, Session session, HeuristicIndexerManager heuristicIndexerManager, - TableExecuteContextManager tableExecuteContextManager) + TableExecuteContextManager tableExecuteContextManager, + PartitionIdAllocator partitionIdAllocator, + Map scheduledTasks, + DynamicFilterService dynamicFilterService) { this.stage = requireNonNull(stage, "stage is null"); this.partitionedNode = requireNonNull(partitionedNode, "partitionedNode is null"); @@ -140,6 +148,9 @@ private SourcePartitionedScheduler( this.splitBatchSize = splitBatchSize; this.groupedExecution = groupedExecution; this.throttledSplitsCount = 0; + this.partitionIdAllocator = partitionIdAllocator; + this.scheduledTasks = scheduledTasks; + this.dynamicFilterService = dynamicFilterService; } @Override @@ -163,15 +174,22 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( int splitBatchSize, Session session, HeuristicIndexerManager heuristicIndexerManager, - TableExecuteContextManager tableExecuteContextManager) + TableExecuteContextManager tableExecuteContextManager, + DynamicFilterService dynamicFilterService) { SourcePartitionedScheduler sourcePartitionedScheduler = new SourcePartitionedScheduler(stage, partitionedNode, splitSource, - splitPlacementPolicy, splitBatchSize, false, session, heuristicIndexerManager, tableExecuteContextManager); + splitPlacementPolicy, splitBatchSize, false, session, heuristicIndexerManager, tableExecuteContextManager, new PartitionIdAllocator(), new HashMap<>(), dynamicFilterService); sourcePartitionedScheduler.startLifespan(Lifespan.taskWide(), NOT_PARTITIONED); sourcePartitionedScheduler.noMoreLifespans(); return new StageScheduler() { + @Override + public Optional start() + { + return sourcePartitionedScheduler.start(); + } + @Override public ScheduleResult schedule() 
{ @@ -214,10 +232,13 @@ public static SourceScheduler newSourcePartitionedSchedulerAsSourceScheduler( boolean groupedExecution, Session session, HeuristicIndexerManager heuristicIndexerManager, - TableExecuteContextManager tableExecuteContextManager) + TableExecuteContextManager tableExecuteContextManager, + PartitionIdAllocator partitionIdAllocator, + Map scheduledTasks, + DynamicFilterService dynamicFilterService) { return new SourcePartitionedScheduler(stage, partitionedNode, splitSource, splitPlacementPolicy, - splitBatchSize, groupedExecution, session, heuristicIndexerManager, tableExecuteContextManager); + splitBatchSize, groupedExecution, session, heuristicIndexerManager, tableExecuteContextManager, partitionIdAllocator, scheduledTasks, dynamicFilterService); } @Override @@ -241,6 +262,37 @@ public synchronized void noMoreLifespans() whenFinishedOrNewLifespanAdded = SettableFuture.create(); } + @Override + public Optional start() + { + // Avoid deadlocks by immediately scheduling a task for collecting dynamic filters because: + // * there can be task in other stage blocked waiting for the dynamic filters, or + // * connector split source for this stage might be blocked waiting the dynamic filters. + if (dynamicFilterService.isCollectingTaskNeeded(stage.getStageId().getQueryId(), stage.getFragment())) { + stage.beginScheduling(); + return createTaskOnRandomNode(); + } + else { + return Optional.empty(); + } + } + + private Optional createTaskOnRandomNode() + { + checkState(scheduledTasks.isEmpty(), "Stage task is already scheduled on node"); + List allNodes = splitPlacementPolicy.allNodes(); + checkState(allNodes.size() > 0, "No nodes available"); + InternalNode node = allNodes.get(ThreadLocalRandom.current().nextInt(0, allNodes.size())); + return scheduleTask(node, ImmutableMultimap.of()); + } + + private Optional scheduleTask(InternalNode node, Multimap initialSplits) + { + Optional remoteTask = stage.scheduleTask(node, partitionIdAllocator.getNextId(), initialSplits); + remoteTask.ifPresent(task -> scheduledTasks.put(node, task)); + return remoteTask; + } + @Override public synchronized ScheduleResult schedule() { @@ -573,7 +625,8 @@ private Set assignSplits(Multimap splitAssignme newTasks.addAll(stage.scheduleSplits( node, splits, - noMoreSplits.build())); + noMoreSplits.build(), + partitionIdAllocator)); } return newTasks.build(); } @@ -590,7 +643,7 @@ private Set finalizeTaskCreationIfNecessary() Set scheduledNodes = stage.getScheduledNodes(); Set newTasks = splitPlacementPolicy.allNodes().stream() .filter(node -> !scheduledNodes.contains(node)) - .flatMap(node -> stage.scheduleSplits(node, ImmutableMultimap.of(), ImmutableMultimap.of()).stream()) + .flatMap(node -> stage.scheduleSplits(node, ImmutableMultimap.of(), ImmutableMultimap.of(), partitionIdAllocator).stream()) .collect(toImmutableSet()); // notify listeners that we have scheduled all tasks so they can set no more buffers or exchange splits diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/SourceScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/SourceScheduler.java index 92ceda528..8d99c1752 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/SourceScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/SourceScheduler.java @@ -15,13 +15,17 @@ package io.prestosql.execution.scheduler; import io.prestosql.execution.Lifespan; +import io.prestosql.execution.RemoteTask; import io.prestosql.spi.connector.ConnectorPartitionHandle; 
import io.prestosql.spi.plan.PlanNodeId; import java.util.List; +import java.util.Optional; public interface SourceScheduler { + Optional start(); + ScheduleResult schedule(); default ScheduleResult schedule(int maxSplitGroupSize) diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java index 93613338e..621438ed9 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java @@ -57,6 +57,9 @@ import io.prestosql.execution.TaskStatus; import io.prestosql.execution.buffer.OutputBuffers; import io.prestosql.execution.buffer.OutputBuffers.OutputBufferId; +import io.prestosql.execution.scheduler.policy.ExecutionPolicy; +import io.prestosql.execution.scheduler.policy.ExecutionSchedule; +import io.prestosql.execution.scheduler.policy.StagesScheduleResult; import io.prestosql.failuredetector.FailureDetector; import io.prestosql.heuristicindex.HeuristicIndexerManager; import io.prestosql.metadata.InternalNode; @@ -676,7 +679,7 @@ private List createStages( checkArgument(!plan.getFragment().getStageExecutionDescriptor().isStageGroupedExecution()); stageSchedulers.put(stageId, newSourcePartitionedSchedulerAsStageScheduler(stageExecution, planNodeId, splitSource, - placementPolicy, splitBatchSize, session, heuristicIndexerManager, tableExecuteContextManager)); + placementPolicy, splitBatchSize, session, heuristicIndexerManager, tableExecuteContextManager, dynamicFilterService)); bucketToPartition = Optional.of(new int[1]); } @@ -759,7 +762,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { connectorPartitionHandles, session, heuristicIndexerManager, - tableExecuteContextManager)); + tableExecuteContextManager, + dynamicFilterService)); } else { // all sources are remote @@ -1028,9 +1032,12 @@ private void schedule() try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { Set completedStages = new HashSet<>(); ExecutionSchedule executionSchedule = executionPolicy.createExecutionSchedule(stages.values()); + stageSchedulers.values().stream().map(stageScheduler -> stageScheduler.start()).filter(Optional::isPresent) + .forEach(task -> stageLinkages.get(task.get().getTaskId().getStageId()).processScheduleResults(stages.get(task.get().getTaskId().getStageId()).getState(), ImmutableSet.of(task.get()))); while (!executionSchedule.isFinished()) { List> blockedStages = new ArrayList<>(); - for (SqlStageExecution stage : executionSchedule.getStagesToSchedule()) { + StagesScheduleResult stagesScheduleResult = executionSchedule.getStagesToSchedule(); + for (SqlStageExecution stage : stagesScheduleResult.getStagesToSchedule()) { if (isReuseTableScanEnabled(session) && !SqlStageExecution.getReuseTableScanMappingIdStatus(stage.getStateMachine())) { continue; } @@ -1111,6 +1118,11 @@ else if (!result.getBlocked().isDone()) { // wait for a state change and then schedule again if (!blockedStages.isEmpty()) { + ImmutableList.Builder> futures = ImmutableList.builder(); + futures.addAll(blockedStages); + // allow for schedule to resume scheduling (e.g. 
when some active stage completes + // and dependent stages can be started) + stagesScheduleResult.getRescheduleFuture().ifPresent(futures::add); try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) { tryGetFutureValue(whenAnyComplete(futures.build()), 1, SECONDS); } diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/StageScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/StageScheduler.java index 55500a352..38026f6e5 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/StageScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/StageScheduler.java @@ -13,11 +13,28 @@ */ package io.prestosql.execution.scheduler; +import io.prestosql.execution.RemoteTask; + import java.io.Closeable; +import java.util.Optional; public interface StageScheduler extends Closeable { + /** + * Called by the query scheduler when the scheduling process begins. + * This method is called before the ExecutionSchedule takes a decision + * to schedule a stage but after the query scheduling has been fully initialized. + * Within this method the scheduler may decide to schedule tasks that + * are necessary for query execution to make progress. + * For example the scheduler may decide to schedule a task without + * assigning any splits to unblock dynamic filter collection. + */ + default Optional start() + { + return Optional.empty(); + } + /** * Schedules as much work as possible without blocking. * The schedule results is a hint to the query scheduler if and diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionPolicy.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/AllAtOnceExecutionPolicy.java similarity index 94% rename from presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionPolicy.java rename to presto-main/src/main/java/io/prestosql/execution/scheduler/policy/AllAtOnceExecutionPolicy.java index 61899b9af..d301a96b0 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionPolicy.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/AllAtOnceExecutionPolicy.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.prestosql.execution.scheduler; +package io.prestosql.execution.scheduler.policy; import io.prestosql.execution.SqlStageExecution; diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionSchedule.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/AllAtOnceExecutionSchedule.java similarity index 97% rename from presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionSchedule.java rename to presto-main/src/main/java/io/prestosql/execution/scheduler/policy/AllAtOnceExecutionSchedule.java index 4e1c2d62f..c9e504ad5 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionSchedule.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/AllAtOnceExecutionSchedule.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ -package io.prestosql.execution.scheduler; +package io.prestosql.execution.scheduler.policy; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; @@ -67,7 +67,7 @@ public AllAtOnceExecutionSchedule(Collection stages) } @Override - public Set getStagesToSchedule() + public StagesScheduleResult getStagesToSchedule() { for (Iterator iterator = schedulingStages.iterator(); iterator.hasNext(); ) { StageState state = iterator.next().getState(); @@ -75,7 +75,7 @@ public Set getStagesToSchedule() iterator.remove(); } } - return schedulingStages; + return new StagesScheduleResult(schedulingStages); } @Override diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/ExecutionPolicy.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/ExecutionPolicy.java similarity index 93% rename from presto-main/src/main/java/io/prestosql/execution/scheduler/ExecutionPolicy.java rename to presto-main/src/main/java/io/prestosql/execution/scheduler/policy/ExecutionPolicy.java index 247407cbd..01dbb641b 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/ExecutionPolicy.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/ExecutionPolicy.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.prestosql.execution.scheduler; +package io.prestosql.execution.scheduler.policy; import io.prestosql.execution.SqlStageExecution; diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/ExecutionSchedule.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/ExecutionSchedule.java similarity index 79% rename from presto-main/src/main/java/io/prestosql/execution/scheduler/ExecutionSchedule.java rename to presto-main/src/main/java/io/prestosql/execution/scheduler/policy/ExecutionSchedule.java index bb49c3870..57dec0490 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/ExecutionSchedule.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/ExecutionSchedule.java @@ -11,15 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.prestosql.execution.scheduler; - -import io.prestosql.execution.SqlStageExecution; - -import java.util.Set; +package io.prestosql.execution.scheduler.policy; public interface ExecutionSchedule { - Set getStagesToSchedule(); + StagesScheduleResult getStagesToSchedule(); boolean isFinished(); } diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionPolicy.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PhasedExecutionPolicy.java similarity index 94% rename from presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionPolicy.java rename to presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PhasedExecutionPolicy.java index c0eb847c3..b75fb4576 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionPolicy.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PhasedExecutionPolicy.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package io.prestosql.execution.scheduler; +package io.prestosql.execution.scheduler.policy; import io.prestosql.execution.SqlStageExecution; diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionSchedule.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PhasedExecutionSchedule.java similarity index 98% rename from presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionSchedule.java rename to presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PhasedExecutionSchedule.java index d80ebbcdd..914d44402 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionSchedule.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PhasedExecutionSchedule.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.prestosql.execution.scheduler; +package io.prestosql.execution.scheduler.policy; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; @@ -78,14 +78,14 @@ public PhasedExecutionSchedule(Collection stages) } @Override - public Set getStagesToSchedule() + public StagesScheduleResult getStagesToSchedule() { removeCompletedStages(); addPhasesIfNecessary(); if (isFinished()) { - return ImmutableSet.of(); + return new StagesScheduleResult(ImmutableSet.of()); } - return activeSources; + return new StagesScheduleResult(activeSources); } private void removeCompletedStages() diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionPolicy.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionPolicy.java new file mode 100644 index 000000000..8969ba210 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionPolicy.java @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.execution.scheduler.policy; + +import io.prestosql.dynamicfilter.DynamicFilterService; +import io.prestosql.execution.SqlStageExecution; + +import javax.inject.Inject; + +import java.util.Collection; + +import static java.util.Objects.requireNonNull; + +public class PrioritizeUtilizationExecutionPolicy + implements ExecutionPolicy +{ + private final DynamicFilterService dynamicFilterService; + + @Inject + public PrioritizeUtilizationExecutionPolicy(DynamicFilterService dynamicFilterService) + { + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + } + + @Override + public ExecutionSchedule createExecutionSchedule(Collection stages) + { + return PrioritizeUtilizationExecutionSchedule.forStages(stages, dynamicFilterService); + } +} diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionSchedule.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionSchedule.java new file mode 100644 index 000000000..57ba128e3 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/PrioritizeUtilizationExecutionSchedule.java @@ -0,0 +1,608 @@ +/* + * Copyright (C) 2018-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.execution.scheduler.policy; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Ordering; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.prestosql.dynamicfilter.DynamicFilterService; +import io.prestosql.execution.SqlStageExecution; +import io.prestosql.execution.StageState; +import io.prestosql.spi.PrestoException; +import io.prestosql.spi.QueryId; +import io.prestosql.spi.plan.AggregationNode; +import io.prestosql.spi.plan.JoinNode; +import io.prestosql.spi.plan.PlanNode; +import io.prestosql.sql.planner.PlanFragment; +import io.prestosql.sql.planner.plan.ExchangeNode; +import io.prestosql.sql.planner.plan.IndexJoinNode; +import io.prestosql.sql.planner.plan.InternalPlanVisitor; +import io.prestosql.sql.planner.plan.PlanFragmentId; +import io.prestosql.sql.planner.plan.RemoteSourceNode; +import io.prestosql.sql.planner.plan.SemiJoinNode; +import io.prestosql.sql.planner.plan.SpatialJoinNode; +import org.jgrapht.DirectedGraph; +import org.jgrapht.EdgeFactory; +import org.jgrapht.alg.StrongConnectivityInspector; +import org.jgrapht.graph.DefaultDirectedGraph; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.execution.StageState.PENDING; +import static io.prestosql.execution.StageState.RUNNING; +import static io.prestosql.execution.StageState.SCHEDULED; +import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.prestosql.spi.plan.AggregationNode.Step.FINAL; +import static io.prestosql.spi.plan.AggregationNode.Step.SINGLE; +import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public class PrioritizeUtilizationExecutionSchedule + implements ExecutionSchedule +{ + /** + * Graph representing a before -> after relationship between fragments. + * Destination fragment should be started only when source stage is completed. + */ + private final DirectedGraph fragmentDependency; + /** + * Graph representing topology between fragments (e.g. child -> parent relationship). 
+ */ + private final DirectedGraph fragmentTopology; + private final Map stagesByFragmentId; + private Ordering fragmentOrdering; + private final List sortedFragments = new ArrayList<>(); + private final Set activeStages = new HashSet<>(); + private final DynamicFilterService dynamicFilterService; + + @GuardedBy("this") + private SettableFuture rescheduleFuture = SettableFuture.create(); + + public static PrioritizeUtilizationExecutionSchedule forStages(Collection stages, DynamicFilterService dynamicFilterService) + { + PrioritizeUtilizationExecutionSchedule schedule = new PrioritizeUtilizationExecutionSchedule(stages, dynamicFilterService); + schedule.init(stages); + return schedule; + } + + private PrioritizeUtilizationExecutionSchedule(Collection stages, DynamicFilterService dynamicFilterService) + { + fragmentDependency = new DefaultDirectedGraph<>(new FragmentsEdgeFactory()); + fragmentTopology = new DefaultDirectedGraph<>(new FragmentsEdgeFactory()); + stagesByFragmentId = stages.stream() + .collect(toImmutableMap(stage -> stage.getFragment().getId(), identity())); + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + } + + private void init(Collection stages) + { + ImmutableSet.Builder fragmentsToExecute = ImmutableSet.builder(); + fragmentsToExecute.addAll(extractDependenciesAndReturnNonLazyFragments(stages)); + // start stages without any dependencies + fragmentDependency.vertexSet().stream() + .filter(fragmentId -> fragmentDependency.inDegreeOf(fragmentId) == 0) + .forEach(fragmentsToExecute::add); + fragmentOrdering = Ordering.explicit(sortedFragments); + selectForExecution(fragmentsToExecute.build()); + } + + @Override + public StagesScheduleResult getStagesToSchedule() + { + // obtain reschedule future before actual scheduling, so that state change + // notifications from previously started stages are not lost + Optional> rescheduleFuture = getRescheduleFuture(); + schedule(); + return new StagesScheduleResult(activeStages, rescheduleFuture); + } + + @Override + public boolean isFinished() + { + // dependency graph contains both running and not started fragments + return fragmentDependency.vertexSet().isEmpty(); + } + + @VisibleForTesting + synchronized Optional> getRescheduleFuture() + { + return Optional.of(rescheduleFuture); + } + + @VisibleForTesting + void schedule() + { + ImmutableSet.Builder fragmentsToExecute = new ImmutableSet.Builder<>(); + fragmentsToExecute.addAll(removeCompletedStages()); + fragmentsToExecute.addAll(unblockStagesWithFullOutputBuffer()); + selectForExecution(fragmentsToExecute.build()); + } + + @VisibleForTesting + DirectedGraph getFragmentDependency() + { + return fragmentDependency; + } + + @VisibleForTesting + Set getActiveStages() + { + return activeStages; + } + + private Set removeCompletedStages() + { + Set completedStages = activeStages.stream() + .filter(this::isStageCompleted) + .collect(toImmutableSet()); + // remove completed stages outside of Java stream to prevent concurrent modification + return completedStages.stream() + .flatMap(stage -> removeCompletedStage(stage).stream()) + .collect(toImmutableSet()); + } + + private Set removeCompletedStage(SqlStageExecution stage) + { + // start all stages that depend on completed stage + PlanFragmentId fragmentId = stage.getFragment().getId(); + Set fragmentsToExecute = fragmentDependency.outgoingEdgesOf(fragmentId).stream() + .map(FragmentsEdge::getTarget) + // filter stages that depend on completed stage only + .filter(dependentFragmentId -> 
fragmentDependency.inDegreeOf(dependentFragmentId) == 1) + .collect(toImmutableSet()); + fragmentDependency.removeVertex(fragmentId); + fragmentTopology.removeVertex(fragmentId); + activeStages.remove(stage); + + return fragmentsToExecute; + } + + private Set unblockStagesWithFullOutputBuffer() + { + // find stages that are blocked on full task output buffer + Set blockedFragments = activeStages.stream() + .filter(SqlStageExecution::isAnyTaskBlocked) + .map(stage -> stage.getFragment().getId()) + .collect(toImmutableSet()); + // start immediate downstream stages so that data can be consumed + Set immediateDownStreamFragments = blockedFragments.stream() + .flatMap(fragmentId -> fragmentTopology.outgoingEdgesOf(fragmentId).stream()) + .map(FragmentsEdge::getTarget) + .collect(Collectors.toSet()); + Set dependentDownStreamFragments = blockedFragments.stream() + .flatMap(fragmentId -> fragmentDependency.outgoingEdgesOf(fragmentId).stream()) + .map(FragmentsEdge::getTarget) + .collect(Collectors.toSet()); + Set fragmentsToExecute = new HashSet<>(); + if (immediateDownStreamFragments.isEmpty()) { + fragmentsToExecute.addAll(dependentDownStreamFragments); + } + else { + fragmentsToExecute.addAll(immediateDownStreamFragments); + } + return fragmentsToExecute; + } + + private void selectForExecution(Set fragmentIds) + { + requireNonNull(fragmentOrdering, "fragmentOrdering is null"); + fragmentIds.stream() + .sorted(fragmentOrdering) + .map(stagesByFragmentId::get) + .forEach(this::selectForExecution); + } + + private void selectForExecution(SqlStageExecution stage) + { + if (isStageCompleted(stage)) { + // don't start completed stages (can happen when non-lazy stage is selected for + // execution and stage is started immediately even with dependencies) + return; + } + + if (fragmentDependency.outDegreeOf(stage.getFragment().getId()) > 0) { + // if there are any dependent stages then reschedule when stage is completed + stage.addStateChangeListener(state -> { + if (isStageCompleted(stage)) { + notifyReschedule(); + } + }); + } + activeStages.add(stage); + } + + private void notifyReschedule() + { + SettableFuture rescheduleFuture; + synchronized (this) { + rescheduleFuture = this.rescheduleFuture; + this.rescheduleFuture = SettableFuture.create(); + } + // notify listeners outside the critical section + rescheduleFuture.set(null); + } + + private boolean isStageCompleted(SqlStageExecution stage) + { + StageState state = stage.getState(); + return state == SCHEDULED || state == RUNNING || state == PENDING || state.isDone(); + } + + private Set extractDependenciesAndReturnNonLazyFragments(Collection stages) + { + if (stages.isEmpty()) { + return ImmutableSet.of(); + } + + QueryId queryId = stages.stream() + .map(stage -> stage.getStageId().getQueryId()) + .findAny().orElseThrow(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "")); + Collection fragments = stages.stream() + .map(SqlStageExecution::getFragment) + .collect(toImmutableList()); + // Build a graph where the plan fragments are vertexes and the edges represent + // a before -> after relationship. Destination fragment should be started only + // when source fragment is completed. For example, a join hash build has an edge + // to the join probe. 
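// Illustrative example (not from the original patch): for a join plan with
//   fragment 1: join fragment, probe side reading fragment 3, build side reading fragment 2
//   fragment 2: build-side table scan
//   fragment 3: probe-side table scan
// the visitor below adds the dependency edge 2 -> 3, so the probe-side scan is started only
// when the build-side scan is completed; for a replicated join it may also add 2 -> 1 so the
// join fragment itself starts lazily. If the build side blocks on a full output buffer first,
// the dependency is overridden by unblockStagesWithFullOutputBuffer.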
+ Visitor visitor = new Visitor(queryId, fragments); + visitor.processAllFragments(); + + // Make sure there are no strongly connected components as it would mean circular dependency between stages + List> components = new StrongConnectivityInspector<>(fragmentDependency).stronglyConnectedSets(); + verify(components.size() == fragmentDependency.vertexSet().size(), "circular dependency between stages"); + + return visitor.getNonLazyFragments(); + } + + private class Visitor + extends InternalPlanVisitor + { + private final QueryId queryId; + private final Map fragments; + private final ImmutableSet.Builder nonLazyFragments = ImmutableSet.builder(); + private final Map fragmentSubGraphs = new HashMap<>(); + + public Visitor(QueryId queryId, Collection fragments) + { + this.queryId = queryId; + this.fragments = requireNonNull(fragments, "fragments is null").stream() + .collect(toImmutableMap(PlanFragment::getId, identity())); + } + + public Set getNonLazyFragments() + { + return nonLazyFragments.build(); + } + + public void processAllFragments() + { + fragments.forEach((fragmentId, fragment) -> { + fragmentDependency.addVertex(fragmentId); + fragmentTopology.addVertex(fragmentId); + }); + fragments.forEach((fragmentId, fragment) -> processFragment(fragmentId)); + } + + public FragmentSubGraph processFragment(PlanFragmentId planFragmentId) + { + if (fragmentSubGraphs.containsKey(planFragmentId)) { + return fragmentSubGraphs.get(planFragmentId); + } + + FragmentSubGraph subGraph = processFragment(fragments.get(planFragmentId)); + verify(fragmentSubGraphs.put(planFragmentId, subGraph) == null, "fragment %s was already processed", planFragmentId); + sortedFragments.add(planFragmentId); + return subGraph; + } + + private FragmentSubGraph processFragment(PlanFragment fragment) + { + FragmentSubGraph subGraph = fragment.getRoot().accept(this, fragment.getId()); + // append current fragment to set of upstream fragments as it is no longer being visited + Set upstreamFragments = ImmutableSet.builder() + .addAll(subGraph.getUpstreamFragments()) + .add(fragment.getId()) + .build(); + Set lazyUpstreamFragments; + if (subGraph.isCurrentFragmentLazy()) { + // append current fragment as a lazy fragment as it is no longer being visited + lazyUpstreamFragments = ImmutableSet.builder() + .addAll(subGraph.getLazyUpstreamFragments()) + .add(fragment.getId()) + .build(); + } + else { + lazyUpstreamFragments = subGraph.getLazyUpstreamFragments(); + nonLazyFragments.add(fragment.getId()); + } + return new FragmentSubGraph( + upstreamFragments, + lazyUpstreamFragments, + // no longer relevant as we have finished visiting given fragment + false); + } + + @Override + public FragmentSubGraph visitJoin(JoinNode node, PlanFragmentId currentFragmentId) + { + return processJoin( + node.getDistributionType().orElseThrow(() -> new NoSuchElementException("No Value Present")) == JoinNode.DistributionType.REPLICATED, + node.getLeft(), + node.getRight(), + currentFragmentId); + } + + @Override + public FragmentSubGraph visitSpatialJoin(SpatialJoinNode node, PlanFragmentId currentFragmentId) + { + return processJoin( + node.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED, + node.getLeft(), + node.getRight(), + currentFragmentId); + } + + @Override + public FragmentSubGraph visitSemiJoin(SemiJoinNode node, PlanFragmentId currentFragmentId) + { + return processJoin( + node.getDistributionType().orElseThrow(() -> new NoSuchElementException("No Value Present")) == SemiJoinNode.DistributionType.REPLICATED, + 
node.getSource(), + node.getFilteringSource(), + currentFragmentId); + } + + @Override + public FragmentSubGraph visitIndexJoin(IndexJoinNode node, PlanFragmentId currentFragmentId) + { + return processJoin( + true, + node.getProbeSource(), + node.getIndexSource(), + currentFragmentId); + } + + private FragmentSubGraph processJoin(boolean replicated, PlanNode probe, PlanNode build, PlanFragmentId currentFragmentId) + { + FragmentSubGraph probeSubGraph = probe.accept(this, currentFragmentId); + FragmentSubGraph buildSubGraph = build.accept(this, currentFragmentId); + + // start probe source stages after all build source stages finish + addDependencyEdges(buildSubGraph.getUpstreamFragments(), probeSubGraph.getLazyUpstreamFragments()); + + boolean currentFragmentLazy = probeSubGraph.isCurrentFragmentLazy() && buildSubGraph.isCurrentFragmentLazy(); + if (replicated && currentFragmentLazy && !dynamicFilterService.isStageSchedulingNeededToCollectDynamicFilters(queryId, fragments.get(currentFragmentId))) { + // Do not start join stage (which can also be a source stage with table scans) + // for replicated join until build source stage enters FLUSHING state. + // Broadcast join limit for CBO is set in such a way that build source data should + // fit into task output buffer. + // In case build source stage is blocked on full task buffer then join stage + // will be started automatically regardless of dependency. This is handled by + // the unblockStagesWithFullOutputBuffer method. + addDependencyEdges(buildSubGraph.getUpstreamFragments(), ImmutableSet.of(currentFragmentId)); + } + else { + // start current fragment immediately since for partitioned join + // build source data won't be able to fit into task output buffer. + currentFragmentLazy = false; + } + + return new FragmentSubGraph( + ImmutableSet.builder() + .addAll(probeSubGraph.getUpstreamFragments()) + .addAll(buildSubGraph.getUpstreamFragments()) + .build(), + // only probe source fragments can be considered lazy + // since build source stages should be started immediately + probeSubGraph.getLazyUpstreamFragments(), + currentFragmentLazy); + } + + @Override + public FragmentSubGraph visitAggregation(AggregationNode node, PlanFragmentId currentFragmentId) + { + FragmentSubGraph subGraph = node.getSource().accept(this, currentFragmentId); + if (node.getStep() != FINAL && node.getStep() != SINGLE) { + return subGraph; + } + + // start current fragment immediately since final/single aggregation will fully + // consume input before producing output data (aggregation shouldn't get blocked) + return new FragmentSubGraph( + subGraph.getUpstreamFragments(), + ImmutableSet.of(), + false); + } + + @Override + public FragmentSubGraph visitRemoteSource(RemoteSourceNode node, PlanFragmentId currentFragmentId) + { + List subGraphs = node.getSourceFragmentIds().stream() + .map(this::processFragment) + .collect(toImmutableList()); + node.getSourceFragmentIds() + .forEach(sourceFragmentId -> fragmentTopology.addEdge(sourceFragmentId, currentFragmentId)); + return new FragmentSubGraph( + subGraphs.stream() + .flatMap(source -> source.getUpstreamFragments().stream()) + .collect(toImmutableSet()), + subGraphs.stream() + .flatMap(source -> source.getLazyUpstreamFragments().stream()) + .collect(toImmutableSet()), + // initially current fragment is considered to be lazy unless there exists + // an operator that can fully consume input data without producing any output + // (e.g.
final aggregation) + true); + } + + @Override + public FragmentSubGraph visitExchange(ExchangeNode node, PlanFragmentId currentFragmentId) + { + checkArgument(node.getScope() == LOCAL, "Only local exchanges are supported in the prioritize utilization scheduler"); + return visitPlan(node, currentFragmentId); + } + + @Override + public FragmentSubGraph visitPlan(PlanNode node, PlanFragmentId currentFragmentId) + { + List sourceSubGraphs = node.getSources().stream() + .map(subPlanNode -> subPlanNode.accept(this, currentFragmentId)) + .collect(toImmutableList()); + + return new FragmentSubGraph( + sourceSubGraphs.stream() + .flatMap(source -> source.getUpstreamFragments().stream()) + .collect(toImmutableSet()), + sourceSubGraphs.stream() + .flatMap(source -> source.getLazyUpstreamFragments().stream()) + .collect(toImmutableSet()), + sourceSubGraphs.stream() + .allMatch(FragmentSubGraph::isCurrentFragmentLazy)); + } + + private void addDependencyEdges(Set sourceFragments, Set targetFragments) + { + for (PlanFragmentId targetFragment : targetFragments) { + for (PlanFragmentId sourceFragment : sourceFragments) { + fragmentDependency.addEdge(sourceFragment, targetFragment); + } + } + } + } + + private static class FragmentSubGraph + { + /** + * All upstream fragments (excluding currently visited fragment) + */ + private final Set upstreamFragments; + /** + * All upstream lazy fragments (excluding currently visited fragment). + * Lazy fragments don't have to be started immediately. + */ + private final Set lazyUpstreamFragments; + /** + * Is currently visited fragment lazy? + */ + private final boolean currentFragmentLazy; + + public FragmentSubGraph( + Set upstreamFragments, + Set lazyUpstreamFragments, + boolean currentFragmentLazy) + { + this.upstreamFragments = requireNonNull(upstreamFragments, "upstreamFragments is null"); + this.lazyUpstreamFragments = requireNonNull(lazyUpstreamFragments, "lazyUpstreamFragments is null"); + this.currentFragmentLazy = currentFragmentLazy; + } + + public Set getUpstreamFragments() + { + return upstreamFragments; + } + + public Set getLazyUpstreamFragments() + { + return lazyUpstreamFragments; + } + + public boolean isCurrentFragmentLazy() + { + return currentFragmentLazy; + } + } + + private static class FragmentsEdgeFactory + implements EdgeFactory + { + @Override + public FragmentsEdge createEdge(PlanFragmentId sourceVertex, PlanFragmentId targetVertex) + { + return new FragmentsEdge(sourceVertex, targetVertex); + } + } + + @VisibleForTesting + static class FragmentsEdge + { + private final PlanFragmentId source; + private final PlanFragmentId target; + + public FragmentsEdge(PlanFragmentId source, PlanFragmentId target) + { + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + } + + public PlanFragmentId getSource() + { + return source; + } + + public PlanFragmentId getTarget() + { + return target; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("source", source) + .add("target", target) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FragmentsEdge that = (FragmentsEdge) o; + return source.equals(that.source) && target.equals(that.target); + } + + @Override + public int hashCode() + { + return Objects.hash(source, target); + } + } +} diff --git 
a/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/StagesScheduleResult.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/StagesScheduleResult.java new file mode 100644 index 000000000..75ebe9679 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/policy/StagesScheduleResult.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.execution.scheduler.policy; + +import com.google.common.util.concurrent.ListenableFuture; +import io.prestosql.execution.SqlStageExecution; + +import java.util.Optional; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +public class StagesScheduleResult +{ + private final Set stagesToSchedule; + private final Optional> rescheduleFuture; + + public StagesScheduleResult(Set stagesToSchedule) + { + this(stagesToSchedule, Optional.empty()); + } + + public StagesScheduleResult(Set stagesToSchedule, Optional> rescheduleFuture) + { + this.stagesToSchedule = requireNonNull(stagesToSchedule, "stagesToSchedule is null"); + this.rescheduleFuture = requireNonNull(rescheduleFuture, "rescheduleFuture is null"); + } + + public Set getStagesToSchedule() + { + return stagesToSchedule; + } + + public Optional> getRescheduleFuture() + { + return rescheduleFuture; + } +} diff --git a/presto-main/src/main/java/io/prestosql/operator/BigintGroupByHash.java b/presto-main/src/main/java/io/prestosql/operator/BigintGroupByHash.java index 5c719de5e..51f4f81cb 100644 --- a/presto-main/src/main/java/io/prestosql/operator/BigintGroupByHash.java +++ b/presto-main/src/main/java/io/prestosql/operator/BigintGroupByHash.java @@ -29,9 +29,11 @@ import org.openjdk.jol.info.ClassLayout; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.SizeOf.sizeOf; import static io.prestosql.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.type.TypeUtils.NULL_HASH_CODE; @@ -51,8 +53,8 @@ public class BigintGroupByHash private int mask; // the hash table from values to groupIds - private LongBigArray values; - private IntBigArray groupIds; + private long[] values; + private int[] groupIds; // reverse index from the groupId back to the value private final LongBigArray valuesByGroupId; @@ -74,10 +76,9 @@ public BigintGroupByHash(int hashChannel, boolean outputRawHash, int expectedSiz maxFill = calculateMaxFill(hashCapacity); mask = hashCapacity - 1; - values = new LongBigArray(); - values.ensureCapacity(hashCapacity); - groupIds = new IntBigArray(-1); - groupIds.ensureCapacity(hashCapacity); + values = new long[hashCapacity]; + groupIds = new int[hashCapacity]; + Arrays.fill(groupIds, -1); valuesByGroupId = new LongBigArray(); valuesByGroupId.ensureCapacity(hashCapacity); @@ -103,8 +104,8 @@ public int getGroupCount() public long getEstimatedSize() { return INSTANCE_SIZE 
+ - groupIds.sizeOf() + - values.sizeOf() + + sizeOf(groupIds) + + sizeOf(values) + valuesByGroupId.sizeOf() + preallocatedMemoryInBytes; } @@ -167,15 +168,15 @@ public boolean contains(int position, Page page, int[] hashChannels) } long value = BIGINT.getLong(block, position); - long hashPosition = getHashPosition(value, mask); + int hashPosition = getHashPosition(value, mask); // look for an empty slot or a slot containing this key while (true) { - int groupId = groupIds.get(hashPosition); + int groupId = groupIds[hashPosition]; if (groupId == -1) { return false; } - if (value == values.get(hashPosition)) { + if (value == values[hashPosition]) { return true; } @@ -213,16 +214,16 @@ public int putIfAbsent(int position, Block block) } long value = BIGINT.getLong(block, position); - long hashPosition = getHashPosition(value, mask); + int hashPosition = getHashPosition(value, mask); // look for an empty slot or a slot containing this key while (true) { - int groupId = groupIds.get(hashPosition); + int groupId = groupIds[hashPosition]; if (groupId == -1) { break; } - if (value == values.get(hashPosition)) { + if (value == values[hashPosition]) { return groupId; } @@ -234,14 +235,14 @@ public int putIfAbsent(int position, Block block) return addNewGroup(hashPosition, value); } - private int addNewGroup(long hashPosition, long value) + private int addNewGroup(int hashPosition, long value) { // record group id in hash int groupId = nextGroupId++; - values.set(hashPosition, value); + values[hashPosition] = value; valuesByGroupId.set(groupId, value); - groupIds.set(hashPosition, groupId); + groupIds[hashPosition] = groupId; // increase capacity, if necessary if (needMoreCapacity()) { @@ -271,10 +272,9 @@ public boolean tryToIncreaseCapacity() expectedHashCollisions += estimateNumberOfHashCollisions(getGroupCount(), hashCapacity); int newMask = newCapacity - 1; - LongBigArray newValues = new LongBigArray(); - newValues.ensureCapacity(newCapacity); - IntBigArray newGroupIds = new IntBigArray(-1); - newGroupIds.ensureCapacity(newCapacity); + long[] newValues = new long[newCapacity]; + int[] newGroupIds = new int[newCapacity]; + Arrays.fill(newGroupIds, -1); for (int groupId = 0; groupId < nextGroupId; groupId++) { if (groupId == nullGroupId) { @@ -283,15 +283,15 @@ public boolean tryToIncreaseCapacity() long value = valuesByGroupId.get(groupId); // find an empty slot for the address - long hashPosition = getHashPosition(value, newMask); - while (newGroupIds.get(hashPosition) != -1) { + int hashPosition = getHashPosition(value, newMask); + while (newGroupIds[hashPosition] != -1) { hashPosition = (hashPosition + 1) & newMask; hashCollisions++; } // record the mapping - newValues.set(hashPosition, value); - newGroupIds.set(hashPosition, groupId); + newValues[hashPosition] = value; + newGroupIds[hashPosition] = groupId; } mask = newMask; @@ -304,9 +304,9 @@ public boolean tryToIncreaseCapacity() return true; } - private static long getHashPosition(long rawHash, int mask) + private static int getHashPosition(long rawHash, int mask) { - return murmurHash3(rawHash) & mask; + return (int) (murmurHash3(rawHash) & mask); } @Override @@ -316,8 +316,8 @@ public Object capture(BlockEncodingSerdeProvider serdeProvider) myState.hashCapacity = hashCapacity; myState.maxFill = maxFill; myState.mask = mask; - myState.values = values.capture(serdeProvider); - myState.groupIds = groupIds.capture(serdeProvider); + myState.values = captureLong(); + myState.groupIds = captureInt(); myState.nullGroupId = nullGroupId; 
myState.valuesByGroupId = valuesByGroupId.capture(serdeProvider); myState.nextGroupId = nextGroupId; @@ -328,6 +328,44 @@ public Object capture(BlockEncodingSerdeProvider serdeProvider) return myState; } + private Object captureLong() + { + LongBigArray.LongBigArrayState myState = new LongBigArray.LongBigArrayState(); + long[] capturedArray = new long[this.values.length]; + for (int i = 0; i < this.values.length; i++) { + capturedArray[i] = this.values[i]; + } + myState.array = new long[][] {capturedArray}; + myState.capacity = this.hashCapacity; + return myState; + } + + private Object captureInt() + { + IntBigArray.IntBigArrayState myState = new IntBigArray.IntBigArrayState(); + int[] capturedArray = new int[this.groupIds.length]; + for (int i = 0; i < this.groupIds.length; i++) { + capturedArray[i] = this.groupIds[i]; + } + myState.array = new int[][] {capturedArray}; + myState.capacity = this.hashCapacity; + return myState; + } + + private void restoreLong(Object state) + { + LongBigArray.LongBigArrayState myState = (LongBigArray.LongBigArrayState) state; + this.values = myState.array[0]; + this.hashCapacity = myState.capacity; + } + + private void restoreInt(Object state) + { + IntBigArray.IntBigArrayState myState = (IntBigArray.IntBigArrayState) state; + this.groupIds = myState.array[0]; + this.hashCapacity = myState.capacity; + } + @Override public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) { @@ -335,8 +373,8 @@ public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) this.hashCapacity = myState.hashCapacity; this.maxFill = myState.maxFill; this.mask = myState.mask; - this.values.restore(myState.values, serdeProvider); - this.groupIds.restore(myState.groupIds, serdeProvider); + restoreLong(myState.values); + restoreInt(myState.groupIds); this.nullGroupId = myState.nullGroupId; this.valuesByGroupId.restore(myState.valuesByGroupId, serdeProvider); this.nextGroupId = myState.nextGroupId; diff --git a/presto-main/src/main/java/io/prestosql/operator/CompletedWork.java b/presto-main/src/main/java/io/prestosql/operator/CompletedWork.java index adb885ee7..428145588 100644 --- a/presto-main/src/main/java/io/prestosql/operator/CompletedWork.java +++ b/presto-main/src/main/java/io/prestosql/operator/CompletedWork.java @@ -13,13 +13,21 @@ */ package io.prestosql.operator; +import javax.annotation.Nullable; + import static java.util.Objects.requireNonNull; public final class CompletedWork implements Work { + @Nullable private final T result; + public CompletedWork() + { + result = null; + } + public CompletedWork(T value) { this.result = requireNonNull(value); @@ -31,6 +39,7 @@ public boolean process() return true; } + @Nullable @Override public T getResult() { diff --git a/presto-main/src/main/java/io/prestosql/operator/GroupAggregationOperator.java b/presto-main/src/main/java/io/prestosql/operator/GroupAggregationOperator.java index 05d3b9467..7f561e622 100644 --- a/presto-main/src/main/java/io/prestosql/operator/GroupAggregationOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/GroupAggregationOperator.java @@ -21,6 +21,7 @@ import io.prestosql.operator.aggregation.Accumulator; import io.prestosql.operator.aggregation.AccumulatorFactory; import io.prestosql.operator.aggregation.builder.AggregationBuilder; +import io.prestosql.operator.aggregation.partial.PartialAggregationController; import io.prestosql.operator.scalar.CombineHashFunction; import io.prestosql.snapshot.SingleInputSnapshotState; import io.prestosql.spi.Page; @@ -39,6 +40,7 @@ import
java.util.Optional; import java.util.stream.Collectors; +import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.operator.aggregation.builder.InMemoryHashAggregationBuilder.toTypes; import static io.prestosql.sql.planner.optimizations.HashGenerationOptimizer.INITIAL_HASH_VALUE; import static io.prestosql.type.TypeUtils.NULL_HASH_CODE; @@ -74,6 +76,7 @@ public static class GroupAggregationOperatorFactory protected final SpillerFactory spillerFactory; protected final JoinCompiler joinCompiler; protected final boolean useSystemMemory; + protected final Optional partialAggregationController; protected boolean closed; @@ -94,7 +97,8 @@ public GroupAggregationOperatorFactory( DataSize unspillMemoryLimit, SpillerFactory spillerFactory, JoinCompiler joinCompiler, - boolean useSystemMemory) + boolean useSystemMemory, + Optional partialAggregationController) { this(operatorId, planNodeId, @@ -113,7 +117,8 @@ public GroupAggregationOperatorFactory( DataSize.succinctBytes((long) (unspillMemoryLimit.toBytes() * MERGE_WITH_MEMORY_RATIO)), spillerFactory, joinCompiler, - useSystemMemory); + useSystemMemory, + partialAggregationController); } public GroupAggregationOperatorFactory( @@ -134,7 +139,8 @@ public GroupAggregationOperatorFactory( DataSize memoryLimitForMergeWithMemory, SpillerFactory spillerFactory, JoinCompiler joinCompiler, - boolean useSystemMemory) + boolean useSystemMemory, + Optional partialAggregationController) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); @@ -154,6 +160,7 @@ public GroupAggregationOperatorFactory( this.spillerFactory = requireNonNull(spillerFactory, "spillerFactory is null"); this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null"); this.useSystemMemory = useSystemMemory; + this.partialAggregationController = requireNonNull(partialAggregationController, "partialAggregationController is null"); } @Override @@ -177,6 +184,7 @@ public OperatorFactory duplicate() protected static final int pageFinalizeLocation = 2; protected final OperatorContext operatorContext; + protected final Optional partialAggregationController; protected final SingleInputSnapshotState snapshotState; protected final List groupByTypes; protected final List groupByChannels; @@ -206,6 +214,8 @@ public OperatorFactory duplicate() // for yield when memory is not available protected Work unfinishedWork; + protected long numberOfInputRowsProcessed; + protected long numberOfUniqueRowsProduced; public GroupAggregationOperator( OperatorContext operatorContext, @@ -224,12 +234,15 @@ public GroupAggregationOperator( DataSize memoryLimitForMergeWithMemory, SpillerFactory spillerFactory, JoinCompiler joinCompiler, - boolean useSystemMemory) + boolean useSystemMemory, + Optional partialAggregationController) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.partialAggregationController = requireNonNull(partialAggregationController, "partialAggregationControl is null"); requireNonNull(step, "step is null"); requireNonNull(accumulatorFactories, "accumulatorFactories is null"); requireNonNull(operatorContext, "operatorContext is null"); + checkArgument(!partialAggregationController.isPresent() || step.isOutputPartial(), "partialAggregationController should be present only for partial aggregation"); this.snapshotState = operatorContext.isSnapshotEnabled() ? 
SingleInputSnapshotState.forOperator(this, operatorContext) : null; this.groupByTypes = ImmutableList.copyOf(groupByTypes); diff --git a/presto-main/src/main/java/io/prestosql/operator/HashAggregationOperator.java b/presto-main/src/main/java/io/prestosql/operator/HashAggregationOperator.java index 1f4f57d66..2a43b7e33 100644 --- a/presto-main/src/main/java/io/prestosql/operator/HashAggregationOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/HashAggregationOperator.java @@ -18,6 +18,8 @@ import io.prestosql.operator.aggregation.AccumulatorFactory; import io.prestosql.operator.aggregation.builder.InMemoryHashAggregationBuilder; import io.prestosql.operator.aggregation.builder.SpillableHashAggregationBuilder; +import io.prestosql.operator.aggregation.partial.PartialAggregationController; +import io.prestosql.operator.aggregation.partial.SkipAggregationBuilder; import io.prestosql.spi.Page; import io.prestosql.spi.plan.AggregationNode.Step; import io.prestosql.spi.plan.PlanNodeId; @@ -53,7 +55,8 @@ public HashAggregationOperatorFactory( int expectedGroups, Optional maxPartialMemory, JoinCompiler joinCompiler, - boolean useSystemMemory) + boolean useSystemMemory, + Optional partialAggregationController) { this(operatorId, planNodeId, @@ -74,7 +77,8 @@ public HashAggregationOperatorFactory( throw new UnsupportedOperationException(); }, joinCompiler, - useSystemMemory); + useSystemMemory, + partialAggregationController); } public HashAggregationOperatorFactory( @@ -94,7 +98,8 @@ public HashAggregationOperatorFactory( DataSize unspillMemoryLimit, SpillerFactory spillerFactory, JoinCompiler joinCompiler, - boolean useSystemMemory) + boolean useSystemMemory, + Optional partialAggregationController) { this(operatorId, planNodeId, @@ -113,7 +118,8 @@ public HashAggregationOperatorFactory( DataSize.succinctBytes((long) (unspillMemoryLimit.toBytes() * MERGE_WITH_MEMORY_RATIO)), spillerFactory, joinCompiler, - useSystemMemory); + useSystemMemory, + partialAggregationController); } @VisibleForTesting @@ -135,7 +141,8 @@ public HashAggregationOperatorFactory( DataSize memoryLimitForMergeWithMemory, SpillerFactory spillerFactory, JoinCompiler joinCompiler, - boolean useSystemMemory) + boolean useSystemMemory, + Optional partialAggregationController) { super( operatorId, @@ -155,7 +162,8 @@ public HashAggregationOperatorFactory( memoryLimitForMergeWithMemory, spillerFactory, joinCompiler, - useSystemMemory); + useSystemMemory, + partialAggregationController); } @Override @@ -181,7 +189,8 @@ public Operator createOperator(DriverContext driverContext) memoryLimitForMergeWithMemory, spillerFactory, joinCompiler, - useSystemMemory); + useSystemMemory, + partialAggregationController); return hashAggregationOperator; } @@ -212,7 +221,8 @@ public OperatorFactory duplicate() memoryLimitForMergeWithMemory, spillerFactory, joinCompiler, - useSystemMemory); + useSystemMemory, + partialAggregationController.map(PartialAggregationController::duplicate)); } } @@ -233,7 +243,8 @@ public HashAggregationOperator( DataSize memoryLimitForMergeWithMemory, SpillerFactory spillerFactory, JoinCompiler joinCompiler, - boolean useSystemMemory) + boolean useSystemMemory, + Optional partialAggregationController) { super(operatorContext, groupByTypes, @@ -251,7 +262,8 @@ public HashAggregationOperator( memoryLimitForMergeWithMemory, spillerFactory, joinCompiler, - useSystemMemory); + useSystemMemory, + partialAggregationController); this.hashCollisionsCounter = new HashCollisionsCounter(operatorContext); 
operatorContext.setInfoSupplier(hashCollisionsCounter); @@ -290,14 +302,21 @@ public void addInput(Page page) unfinishedWork = null; } aggregationBuilder.updateMemory(); + numberOfInputRowsProcessed += page.getPositionCount(); } @Override public void createAggregationBuilder() { - // TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling. - if (step.isOutputPartial() || !spillEnabled || hasOrderBy() || hasDistinct()) { - //TODO-cp-I39B76 snapshot support + boolean partialAggregationDisabled = partialAggregationController + .map(PartialAggregationController::isPartialAggregationDisabled) + .orElse(false); + if (step.isOutputPartial() && partialAggregationDisabled) { + aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, accumulatorFactories, memoryContext); + } + else if (step.isOutputPartial() || !spillEnabled || hasOrderBy() || hasDistinct()) { + // TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling. + // TODO-cp-I39B76 snapshot support aggregationBuilder = new InMemoryHashAggregationBuilder( accumulatorFactories, step, @@ -391,7 +410,9 @@ public Page getOutput() return null; } - return outputPages.getResult(); + Page result = outputPages.getResult(); + numberOfUniqueRowsProduced += result.getPositionCount(); + return result; } @Override @@ -414,5 +435,9 @@ protected void closeAggregationBuilder() aggregationBuilder = null; } memoryContext.setBytes(0); + partialAggregationController.ifPresent( + controller -> controller.onFlush(numberOfInputRowsProcessed, numberOfUniqueRowsProduced)); + numberOfInputRowsProcessed = 0; + numberOfUniqueRowsProduced = 0; } } diff --git a/presto-main/src/main/java/io/prestosql/operator/OperatorInfo.java b/presto-main/src/main/java/io/prestosql/operator/OperatorInfo.java index f80d72c90..129ea319f 100644 --- a/presto-main/src/main/java/io/prestosql/operator/OperatorInfo.java +++ b/presto-main/src/main/java/io/prestosql/operator/OperatorInfo.java @@ -15,9 +15,9 @@ import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import io.prestosql.operator.PartitionedOutputOperator.PartitionedOutputInfo; import io.prestosql.operator.TableWriterOperator.TableWriterInfo; import io.prestosql.operator.exchange.LocalExchangeBufferInfo; +import io.prestosql.operator.output.PartitionedOutputOperator.PartitionedOutputInfo; @JsonTypeInfo( use = JsonTypeInfo.Id.NAME, diff --git a/presto-main/src/main/java/io/prestosql/operator/SortAggregationOperator.java b/presto-main/src/main/java/io/prestosql/operator/SortAggregationOperator.java index daeb68eb2..30da48711 100644 --- a/presto-main/src/main/java/io/prestosql/operator/SortAggregationOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/SortAggregationOperator.java @@ -21,6 +21,7 @@ import io.prestosql.operator.aggregation.builder.InMemoryHashAggregationBuilder; import io.prestosql.operator.aggregation.builder.InMemorySortAggregationBuilder; import io.prestosql.operator.aggregation.builder.SpillableHashAggregationBuilder; +import io.prestosql.operator.aggregation.partial.PartialAggregationController; import io.prestosql.spi.Page; import io.prestosql.spi.plan.AggregationNode; import io.prestosql.spi.plan.PlanNodeId; @@ -56,11 +57,11 @@ public SortAggregationOperatorFactory(int operatorId, PlanNodeId planNodeId, Lis Optional groupIdChannel, int expectedGroups, Optional 
maxPartialMemory, boolean spillEnabled, DataSize unspillMemoryLimit, SpillerFactory spillerFactory, - JoinCompiler joinCompiler, boolean useSystemMemory, boolean isFinalizedValuePresent) + JoinCompiler joinCompiler, boolean useSystemMemory, boolean isFinalizedValuePresent, Optional partialAggregationController) { super(operatorId, planNodeId, groupByTypes, groupByChannels, globalAggregationGroupIds, step, produceDefaultOutput, accumulatorFactories, hashChannel, groupIdChannel, expectedGroups, maxPartialMemory, spillEnabled, - unspillMemoryLimit, spillerFactory, joinCompiler, useSystemMemory); + unspillMemoryLimit, spillerFactory, joinCompiler, useSystemMemory, partialAggregationController); this.isFinalizedValuePresent = isFinalizedValuePresent; } @@ -83,7 +84,8 @@ public SortAggregationOperatorFactory(int operatorId, PlanNodeId planNodeId, Lis SpillerFactory spillerFactory, JoinCompiler joinCompiler, boolean useSystemMemory, - boolean isFinalizedValuePresent) + boolean isFinalizedValuePresent, + Optional partialAggregationController) { super( operatorId, @@ -103,7 +105,8 @@ public SortAggregationOperatorFactory(int operatorId, PlanNodeId planNodeId, Lis memoryLimitForMergeWithMemory, spillerFactory, joinCompiler, - useSystemMemory); + useSystemMemory, + partialAggregationController); this.isFinalizedValuePresent = isFinalizedValuePresent; } @@ -131,7 +134,8 @@ public Operator createOperator(DriverContext driverContext) spillerFactory, joinCompiler, useSystemMemory, - isFinalizedValuePresent); + isFinalizedValuePresent, + partialAggregationController); return sortAggregationOperator; } @@ -157,7 +161,8 @@ public OperatorFactory duplicate() spillerFactory, joinCompiler, useSystemMemory, - isFinalizedValuePresent); + isFinalizedValuePresent, + partialAggregationController.map(PartialAggregationController::duplicate)); } @Override @@ -178,11 +183,11 @@ public SortAggregationOperator(OperatorContext operatorContext, List group Optional maxPartialMemory, boolean spillEnabled, DataSize memoryLimitForMerge, DataSize memoryLimitForMergeWithMemory, SpillerFactory spillerFactory, JoinCompiler joinCompiler, boolean useSystemMemory, - boolean isFinalizedValuePresent) + boolean isFinalizedValuePresent, Optional partialAggregationController) { super(operatorContext, groupByTypes, groupByChannels, globalAggregationGroupIds, step, produceDefaultOutput, accumulatorFactories, hashChannel, groupIdChannel, expectedGroups, maxPartialMemory, spillEnabled, - memoryLimitForMerge, memoryLimitForMergeWithMemory, spillerFactory, joinCompiler, useSystemMemory); + memoryLimitForMerge, memoryLimitForMergeWithMemory, spillerFactory, joinCompiler, useSystemMemory, partialAggregationController); this.isFinalizedValuePresent = isFinalizedValuePresent; } diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/partial/PartialAggregationController.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/partial/PartialAggregationController.java new file mode 100644 index 000000000..fcfad0060 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/partial/PartialAggregationController.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.aggregation.partial; + +import io.prestosql.operator.HashAggregationOperator; + +/** + * Controls whether partial aggregation is enabled across all {@link HashAggregationOperator}s + * for a particular plan node on a single node. + * Partial aggregation is disabled once enough rows have been processed ({@link #minNumberOfRowsProcessed}) + * and the ratio between output (unique) and input rows is too high (> {@link #uniqueRowsRatioThreshold}). + * TODO https://github.com/trinodb/trino/issues/11361 add support to adaptively re-enable partial aggregation. + *

+ * The class is thread safe and objects of this class are used potentially by multiple threads/drivers simultaneously. + * Different threads either: + * - modify fields via synchronized {@link #onFlush}. + * - read volatile {@link #partialAggregationDisabled} (volatile here gives visibility). + */ +public class PartialAggregationController +{ + private final long minNumberOfRowsProcessed; + private final double uniqueRowsRatioThreshold; + + private volatile boolean partialAggregationDisabled; + private long totalRowProcessed; + private long totalUniqueRowsProduced; + + public PartialAggregationController(long minNumberOfRowsProcessedToDisable, double uniqueRowsRatioThreshold) + { + this.minNumberOfRowsProcessed = minNumberOfRowsProcessedToDisable; + this.uniqueRowsRatioThreshold = uniqueRowsRatioThreshold; + } + + public boolean isPartialAggregationDisabled() + { + return partialAggregationDisabled; + } + + public synchronized void onFlush(long rowsProcessed, long uniqueRowsProduced) + { + if (partialAggregationDisabled) { + return; + } + + totalRowProcessed += rowsProcessed; + totalUniqueRowsProduced += uniqueRowsProduced; + if (shouldDisablePartialAggregation()) { + partialAggregationDisabled = true; + } + } + + private boolean shouldDisablePartialAggregation() + { + return totalRowProcessed >= minNumberOfRowsProcessed + && ((double) totalUniqueRowsProduced / totalRowProcessed) > uniqueRowsRatioThreshold; + } + + public PartialAggregationController duplicate() + { + return new PartialAggregationController(minNumberOfRowsProcessed, uniqueRowsRatioThreshold); + } +} diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/partial/SkipAggregationBuilder.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/partial/SkipAggregationBuilder.java new file mode 100644 index 000000000..c7beb4afb --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/partial/SkipAggregationBuilder.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
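As a quick illustration of the controller above (a minimal sketch with hypothetical thresholds of 100_000 rows and a 0.9 unique-rows ratio; the values actually used come from configuration that is not part of this change), partial aggregation is only switched off once both conditions hold across the accumulated flush totals:

    PartialAggregationController controller = new PartialAggregationController(100_000, 0.9);

    // First flush: 50_000 input rows, 49_000 distinct groups. The ratio (0.98) already exceeds 0.9,
    // but the 100_000-row minimum has not been reached yet, so partial aggregation stays enabled.
    controller.onFlush(50_000, 49_000);
    assert !controller.isPartialAggregationDisabled();

    // Second flush: cumulative totals become 110_000 input rows and 104_000 unique rows (ratio ~0.95).
    // Both thresholds are now exceeded, so the controller disables partial aggregation for this plan node.
    controller.onFlush(60_000, 55_000);
    assert controller.isPartialAggregationDisabled();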
+ */ +package io.prestosql.operator.aggregation.partial; + +import com.google.common.util.concurrent.ListenableFuture; +import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.operator.CompletedWork; +import io.prestosql.operator.GroupByIdBlock; +import io.prestosql.operator.HashCollisionsCounter; +import io.prestosql.operator.Work; +import io.prestosql.operator.WorkProcessor; +import io.prestosql.operator.aggregation.AccumulatorFactory; +import io.prestosql.operator.aggregation.GroupedAccumulator; +import io.prestosql.operator.aggregation.builder.AggregationBuilder; +import io.prestosql.spi.Page; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.LongArrayBlock; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * {@link AggregationBuilder} that does not aggregate input rows at all. + * It passes the input pages, augmented with initial accumulator state to the output. + * It can only be used at the partial aggregation step as it relies on rows be aggregated at the final step. + */ +public class SkipAggregationBuilder + implements AggregationBuilder +{ + private final LocalMemoryContext memoryContext; + private final List groupedAccumulators; + @Nullable + private Page currentPage; + private final int[] hashChannels; + + public SkipAggregationBuilder( + List groupByChannels, + Optional inputHashChannel, + List accumulatorFactories, + LocalMemoryContext memoryContext) + { + this.memoryContext = requireNonNull(memoryContext, "memoryContext is null"); + this.groupedAccumulators = requireNonNull(accumulatorFactories, "aggregatorFactories is null") + .stream() + .map(AccumulatorFactory::createGroupedAccumulator) + .collect(toImmutableList()); + this.hashChannels = new int[groupByChannels.size() + (inputHashChannel.isPresent() ? 
1 : 0)]; + for (int i = 0; i < groupByChannels.size(); i++) { + hashChannels[i] = groupByChannels.get(i); + } + inputHashChannel.ifPresent(channelIndex -> hashChannels[groupByChannels.size()] = channelIndex); + } + + @Override + public Work processPage(Page page) + { + checkArgument(currentPage == null); + currentPage = page; + return new CompletedWork<>(); + } + + @Override + public WorkProcessor buildResult() + { + if (currentPage == null) { + return WorkProcessor.of(); + } + + Page result = buildOutputPage(currentPage); + currentPage = null; + return WorkProcessor.of(result); + } + + @Override + public boolean isFull() + { + return currentPage != null; + } + + @Override + public void updateMemory() + { + if (currentPage != null) { + memoryContext.setBytes(currentPage.getSizeInBytes()); + } + } + + @Override + public void recordHashCollisions(HashCollisionsCounter hashCollisionsCounter) + { + // no op + } + + @Override + public void close() + { + } + + @Override + public ListenableFuture startMemoryRevoke() + { + throw new UnsupportedOperationException("startMemoryRevoke not supported for SkipAggregationBuilder"); + } + + @Override + public void finishMemoryRevoke() + { + throw new UnsupportedOperationException("finishMemoryRevoke not supported for SkipAggregationBuilder"); + } + + private Page buildOutputPage(Page page) + { + populateInitialAccumulatorState(page); + + BlockBuilder[] outputBuilders = serializeAccumulatorState(page.getPositionCount()); + + return constructOutputPage(page, outputBuilders); + } + + private void populateInitialAccumulatorState(Page page) + { + GroupByIdBlock groupByIdBlock = getGroupByIdBlock(page.getPositionCount()); + for (GroupedAccumulator groupedAccumulator : groupedAccumulators) { + groupedAccumulator.addInput(groupByIdBlock, page); + } + } + + private GroupByIdBlock getGroupByIdBlock(int positionCount) + { + return new GroupByIdBlock( + positionCount, + new LongArrayBlock(positionCount, Optional.empty(), consecutive(positionCount))); + } + + private BlockBuilder[] serializeAccumulatorState(int positionCount) + { + BlockBuilder[] outputBuilders = new BlockBuilder[groupedAccumulators.size()]; + for (int i = 0; i < outputBuilders.length; i++) { + outputBuilders[i] = groupedAccumulators.get(i).getIntermediateType().createBlockBuilder(null, positionCount); + } + + for (int position = 0; position < positionCount; position++) { + for (int i = 0; i < groupedAccumulators.size(); i++) { + GroupedAccumulator groupedAccumulator = groupedAccumulators.get(i); + BlockBuilder output = outputBuilders[i]; + groupedAccumulator.evaluateIntermediate(position, output); + } + } + return outputBuilders; + } + + private Page constructOutputPage(Page page, BlockBuilder[] outputBuilders) + { + Block[] outputBlocks = new Block[hashChannels.length + outputBuilders.length]; + for (int i = 0; i < hashChannels.length; i++) { + outputBlocks[i] = page.getBlock(hashChannels[i]); + } + for (int i = 0; i < outputBuilders.length; i++) { + outputBlocks[hashChannels.length + i] = outputBuilders[i].build(); + } + return new Page(page.getPositionCount(), outputBlocks); + } + + private static long[] consecutive(int positionCount) + { + long[] longs = new long[positionCount]; + for (int i = 0; i < positionCount; i++) { + longs[i] = i; + } + return longs; + } +} diff --git a/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java new file mode 100644 index 000000000..62eac9290 --- /dev/null +++ 
b/presto-main/src/main/java/io/prestosql/operator/output/PagePartitioner.java @@ -0,0 +1,507 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.output; + +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.units.DataSize; +import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.exchange.FileSystemExchangeConfig; +import io.prestosql.execution.buffer.OutputBuffer; +import io.prestosql.operator.OperatorContext; +import io.prestosql.operator.PartitionFunction; +import io.prestosql.spi.Page; +import io.prestosql.spi.PageBuilder; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.DictionaryBlock; +import io.prestosql.spi.block.RunLengthEncodedBlock; +import io.prestosql.spi.predicate.NullableValue; +import io.prestosql.spi.snapshot.BlockEncodingSerdeProvider; +import io.prestosql.spi.snapshot.Restorable; +import io.prestosql.spi.snapshot.RestorableConfig; +import io.prestosql.spi.type.Type; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.ints.IntList; + +import java.io.Serializable; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.IntUnaryOperator; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.execution.buffer.PageSplitterUtil.splitPage; +import static io.prestosql.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.util.Objects.requireNonNull; + +@RestorableConfig(stateClassName = "PagePartitionerState", uncapturedFields = {"outputBuffer", "sourceTypes", "partitionFunction", "partitionChannels", + "partitionConstants", "operatorContext", "pageBuilders"}) +public class PagePartitioner + implements Restorable +{ + private static final int COLUMNAR_STRATEGY_COEFFICIENT = 4; + final String id; + //shared field + final OutputBuffer outputBuffer; + private final Type[] sourceTypes; + private final PartitionFunction partitionFunction; + private final List partitionChannels; + private final List> partitionConstants; + private final PageBuilder[] pageBuilders; + private final boolean replicatesAnyRow; + private final int nullChannel; // when >= 0, send the position to every partition if this channel is null + private final AtomicLong rowsAdded = new AtomicLong(); + private final AtomicLong pagesAdded = new AtomicLong(); + private boolean hasAnyRowBeenReplicated; + private final OperatorContext operatorContext; + private final PositionsAppenderFactory positionsAppenderFactory; + private final PositionsAppender[] positionsAppenders; + + public PagePartitioner( + String id, + 
PartitionFunction partitionFunction, + List partitionChannels, + List> partitionConstants, + boolean replicatesAnyRow, + OptionalInt nullChannel, + OutputBuffer outputBuffer, + OperatorContext operatorContext, + List sourceTypes, + DataSize maxMemory, + PositionsAppenderFactory positionsAppenderFactory) + { + this.id = id; + this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); + this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null"); + this.positionsAppenderFactory = requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null"); + this.partitionConstants = requireNonNull(partitionConstants, "partitionConstants is null").stream() + .map(constant -> constant.map(NullableValue::asBlock)) + .collect(toImmutableList()); + this.replicatesAnyRow = replicatesAnyRow; + this.nullChannel = requireNonNull(nullChannel, "nullChannel is null").orElse(-1); + this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]); + this.operatorContext = requireNonNull(operatorContext, "serde is null"); + + int partitionCount = partitionFunction.getPartitionCount(); + int pageSize = min(DEFAULT_MAX_PAGE_SIZE_IN_BYTES, ((int) maxMemory.toBytes()) / partitionCount); + pageSize = max(1, pageSize); + + this.pageBuilders = new PageBuilder[partitionCount]; + for (int i = 0; i < partitionCount; i++) { + pageBuilders[i] = PageBuilder.withMaxPageSize(pageSize, sourceTypes); + } + positionsAppenders = new PositionsAppender[sourceTypes.size()]; + } + + public ListenableFuture isFull() + { + return outputBuffer.isFull(); + } + + public long getSizeInBytes() + { + // We use a foreach loop instead of streams + // as it has much better performance. + long sizeInBytes = 0; + for (PageBuilder pageBuilder : pageBuilders) { + sizeInBytes += pageBuilder.getSizeInBytes(); + } + return sizeInBytes; + } + + /** + * This method can be expensive for complex types. + */ + public long getRetainedSizeInBytes() + { + long sizeInBytes = 0; + for (PageBuilder pageBuilder : pageBuilders) { + sizeInBytes += pageBuilder.getRetainedSizeInBytes(); + } + return sizeInBytes; + } + + public PartitionedOutputOperator.PartitionedOutputInfo getInfo() + { + return new PartitionedOutputOperator.PartitionedOutputInfo(rowsAdded.get(), pagesAdded.get(), outputBuffer.getPeakMemoryUsage()); + } + + public void partitionPage(Page page) + { + if (page.getPositionCount() == 0) { + return; + } + + if (page.getPositionCount() < partitionFunction.getPartitionCount() * COLUMNAR_STRATEGY_COEFFICIENT) { + // Partition will have on average less than COLUMNAR_STRATEGY_COEFFICIENT rows. + // Doing it column-wise would degrade performance, so we fall back to row-wise approach. + // Performance degradation is the worst in case of skewed hash distribution when only small subset + // of partitions is selected. 
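To make the threshold above concrete, here is a hypothetical standalone restatement of the strategy choice (not part of this change; the real code simply branches into partitionPageByRow or partitionPageByColumn right below):

    // Sketch: column-wise appending only pays off when each partition receives several rows on average.
    static boolean useColumnarStrategy(int positionCount, int partitionCount)
    {
        int columnarStrategyCoefficient = 4; // mirrors COLUMNAR_STRATEGY_COEFFICIENT above
        return positionCount >= partitionCount * columnarStrategyCoefficient;
    }

    // With 100 partitions: a 300-position page goes row-wise (fewer than 4 rows per partition on average),
    // while an 8192-position page goes column-wise (roughly 82 rows per partition on average).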
+ partitionPageByRow(page); + } + else { + partitionPageByColumn(page); + } + } + + public void partitionPageByRow(Page page) + { + requireNonNull(page, "page is null"); + if (page.getPositionCount() == 0) { + return; + } + + int position; + // Handle "any row" replication outside of the inner loop processing + if (replicatesAnyRow && !hasAnyRowBeenReplicated) { + for (PageBuilder pageBuilder : pageBuilders) { + appendRow(pageBuilder, page, 0); + } + hasAnyRowBeenReplicated = true; + position = 1; + } + else { + position = 0; + } + + Page partitionFunctionArgs = getPartitionFunctionArguments(page); + // Skip null block checks if mayHaveNull reports that no positions will be null + if (nullChannel >= 0 && page.getBlock(nullChannel).mayHaveNull()) { + Block nullsBlock = page.getBlock(nullChannel); + for (; position < page.getPositionCount(); position++) { + if (nullsBlock.isNull(position)) { + for (PageBuilder pageBuilder : pageBuilders) { + appendRow(pageBuilder, page, position); + } + } + else { + int partition = partitionFunction.getPartition(partitionFunctionArgs, position); + appendRow(pageBuilders[partition], page, position); + } + } + } + else { + for (; position < page.getPositionCount(); position++) { + int partition = partitionFunction.getPartition(partitionFunctionArgs, position); + appendRow(pageBuilders[partition], page, position); + } + } + + flush(false); + } + + private void appendRow(PageBuilder pageBuilder, Page page, int position) + { + pageBuilder.declarePosition(); + + for (int channel = 0; channel < sourceTypes.length; channel++) { + Type type = sourceTypes[channel]; + type.appendTo(page.getBlock(channel), position, pageBuilder.getBlockBuilder(channel)); + } + } + + public void partitionPageByColumn(Page page) + { + IntArrayList[] partitionedPositions = partitionPositions(page); + + PositionsAppender[] positionsAppenders = getAppenders(page); + + for (int i = 0; i < partitionFunction.getPartitionCount(); i++) { + IntArrayList partitionPositions = partitionedPositions[i]; + if (!partitionPositions.isEmpty()) { + appendToOutputPartition(pageBuilders[i], page, partitionPositions, positionsAppenders); + partitionPositions.clear(); + } + } + + flush(false); + } + + private PositionsAppender[] getAppenders(Page page) + { + for (int i = 0; i < positionsAppenders.length; i++) { + positionsAppenders[i] = positionsAppenderFactory.create(sourceTypes[i], page.getBlock(i).getClass()); + } + return positionsAppenders; + } + + private IntArrayList[] partitionPositions(Page page) + { + verify(page.getPositionCount() > 0, "position count is 0"); + IntArrayList[] partitionPositions = initPositions(page); + int position; + // Handle "any row" replication outside the inner loop processing + if (replicatesAnyRow && !hasAnyRowBeenReplicated) { + for (IntList partitionPosition : partitionPositions) { + partitionPosition.add(0); + } + hasAnyRowBeenReplicated = true; + position = 1; + } + else { + position = 0; + } + + Page partitionFunctionArgs = getPartitionFunctionArguments(page); + + if (partitionFunctionArgs.getChannelCount() > 0 && onlyRleBlocks(partitionFunctionArgs)) { + // we need at least one Rle block since with no blocks partition function + // can return a different value per invocation (e.g. 
RoundRobinBucketFunction) + partitionBySingleRleValue(page, position, partitionFunctionArgs, partitionPositions); + } + else if (partitionFunctionArgs.getChannelCount() == 1 && isDictionaryProcessingFaster(partitionFunctionArgs.getBlock(0))) { + partitionBySingleDictionary(page, position, partitionFunctionArgs, partitionPositions); + } + else { + partitionGeneric(page, position, aPosition -> partitionFunction.getPartition(partitionFunctionArgs, aPosition), partitionPositions); + } + return partitionPositions; + } + + private void appendToOutputPartition(PageBuilder outputPartition, Page page, IntArrayList positions, PositionsAppender[] positionsAppenders) + { + outputPartition.declarePositions(positions.size()); + + for (int channel = 0; channel < positionsAppenders.length; channel++) { + Block partitionBlock = page.getBlock(channel); + BlockBuilder target = outputPartition.getBlockBuilder(channel); + positionsAppenders[channel].appendTo(positions, partitionBlock, target); + } + } + + private IntArrayList[] initPositions(Page page) + { + // We allocate new arrays for every page (instead of caching them) because we don't + // want memory to explode in case there are input pages with many positions, where each page + // is assigned to a single partition entirely. + // For example this can happen for partition columns if they are represented by RLE blocks. + IntArrayList[] partitionPositions = new IntArrayList[partitionFunction.getPartitionCount()]; + for (int i = 0; i < partitionPositions.length; i++) { + partitionPositions[i] = new IntArrayList(initialPartitionSize(page.getPositionCount() / partitionFunction.getPartitionCount())); + } + return partitionPositions; + } + + private static int initialPartitionSize(int averagePositionsPerPartition) + { + // 1.1 coefficient compensates for the not perfect hash distribution. + // 32 compensates for the case when averagePositionsPerPartition is small, + // and we would see more variance in the hash distribution. 
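For a sense of scale with the return statement that follows (hypothetical numbers, for illustration only): a page of 8192 positions split across 100 partitions gives averagePositionsPerPartition = 81, so each position list starts with a capacity of (int) (81 * 1.1) + 32 = 121, leaving headroom for moderately skewed hash distributions without reallocation.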
+ return (int) (averagePositionsPerPartition * 1.1) + 32; + } + + private boolean onlyRleBlocks(Page page) + { + for (int i = 0; i < page.getChannelCount(); i++) { + if (!(page.getBlock(i) instanceof RunLengthEncodedBlock)) { + return false; + } + } + return true; + } + + private void partitionBySingleRleValue(Page page, int position, Page partitionFunctionArgs, IntArrayList[] partitionPositions) + { + // copy all positions because all hash function args are the same for every position + if (nullChannel != -1 && page.getBlock(nullChannel).isNull(0)) { + verify(page.getBlock(nullChannel) instanceof RunLengthEncodedBlock, "null channel is not RunLengthEncodedBlock", page.getBlock(nullChannel)); + // all positions are null + int[] allPositions = integersInRange(position, page.getPositionCount()); + for (IntList partitionPosition : partitionPositions) { + partitionPosition.addElements(position, allPositions); + } + } + else { + // extract rle page to prevent JIT profile pollution + Page rlePage = extractRlePage(partitionFunctionArgs); + + int partition = partitionFunction.getPartition(rlePage, 0); + IntArrayList positions = partitionPositions[partition]; + for (int i = position; i < page.getPositionCount(); i++) { + positions.add(i); + } + } + } + + private Page extractRlePage(Page page) + { + Block[] valueBlocks = new Block[page.getChannelCount()]; + for (int channel = 0; channel < valueBlocks.length; ++channel) { + valueBlocks[channel] = ((RunLengthEncodedBlock) page.getBlock(channel)).getValue(); + } + return new Page(valueBlocks); + } + + private int[] integersInRange(int start, int endExclusive) + { + int[] array = new int[endExclusive - start]; + int current = start; + for (int i = 0; i < array.length; i++) { + array[i] = current++; + } + return array; + } + + private boolean isDictionaryProcessingFaster(Block block) + { + if (!(block instanceof DictionaryBlock)) { + return false; + } + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + // if dictionary block positionCount is greater than number of elements in the dictionary + // it will be faster to compute hash for the dictionary values only once and re-use it + // instead of recalculating it. 
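To put this heuristic in numbers (hypothetical, for illustration only): a DictionaryBlock with 10_000 positions backed by a 50-entry dictionary needs just 50 partition-function evaluations plus one cheap dictionaryPartitions[getId(position)] lookup per row, instead of 10_000 full evaluations on the generic path. When the dictionary is at least as large as the block, the pre-computation buys nothing, which is exactly what the positionCount comparison below guards against.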
+ return dictionaryBlock.getPositionCount() > dictionaryBlock.getDictionary().getPositionCount(); + } + + private void partitionBySingleDictionary(Page page, int position, Page partitionFunctionArgs, IntArrayList[] partitionPositions) + { + DictionaryBlock dictionaryBlock = (DictionaryBlock) partitionFunctionArgs.getBlock(0); + Block dictionary = dictionaryBlock.getDictionary(); + int[] dictionaryPartitions = new int[dictionary.getPositionCount()]; + Page dictionaryPage = new Page(dictionary); + for (int i = 0; i < dictionary.getPositionCount(); i++) { + dictionaryPartitions[i] = partitionFunction.getPartition(dictionaryPage, i); + } + + partitionGeneric(page, position, aPosition -> dictionaryPartitions[dictionaryBlock.getId(aPosition)], partitionPositions); + } + + private void partitionGeneric(Page page, int position, IntUnaryOperator partitionFunction, IntArrayList[] partitionPositions) + { + // Skip null block checks if mayHaveNull reports that no positions will be null + if (nullChannel != -1 && page.getBlock(nullChannel).mayHaveNull()) { + partitionNullablePositions(page, position, partitionPositions, partitionFunction); + } + else { + partitionNotNullPositions(page, position, partitionPositions, partitionFunction); + } + } + + private IntArrayList[] partitionNullablePositions(Page page, int position, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) + { + Block nullsBlock = page.getBlock(nullChannel); + int[] nullPositions = new int[page.getPositionCount()]; + int[] nonNullPositions = new int[page.getPositionCount()]; + int nullCount = 0; + int nonNullCount = 0; + for (int i = position; i < page.getPositionCount(); i++) { + nullPositions[nullCount] = i; + nonNullPositions[nonNullCount] = i; + int isNull = nullsBlock.isNull(i) ? 
1 : 0; + nullCount += isNull; + nonNullCount += isNull ^ 1; + } + for (IntArrayList positions : partitionPositions) { + positions.addElements(position, nullPositions, 0, nullCount); + } + for (int i = 0; i < nonNullCount; i++) { + int nonNullPosition = nonNullPositions[i]; + int partition = partitionFunction.applyAsInt(nonNullPosition); + partitionPositions[partition].add(nonNullPosition); + } + return partitionPositions; + } + + private IntArrayList[] partitionNotNullPositions(Page page, int startingPosition, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) + { + for (int position = startingPosition; position < page.getPositionCount(); position++) { + int partition = partitionFunction.applyAsInt(position); + partitionPositions[partition].add(position); + } + + return partitionPositions; + } + + private Page getPartitionFunctionArguments(Page page) + { + Block[] blocks = new Block[partitionChannels.size()]; + for (int i = 0; i < blocks.length; i++) { + Optional partitionConstant = partitionConstants.get(i); + if (partitionConstant.isPresent()) { + blocks[i] = new RunLengthEncodedBlock(partitionConstant.get(), page.getPositionCount()); + } + else { + blocks[i] = page.getBlock(partitionChannels.get(i)); + } + } + return new Page(page.getPositionCount(), blocks); + } + + public void flush(boolean force) + { + // add all full pages to output buffer + for (int partition = 0; partition < pageBuilders.length; partition++) { + PageBuilder partitionPageBuilder = pageBuilders[partition]; + if (!partitionPageBuilder.isEmpty() && (force || partitionPageBuilder.isFull())) { + Page pagePartition = partitionPageBuilder.build(); + partitionPageBuilder.reset(); + + FileSystemExchangeConfig.DirectSerialisationType serialisationType = outputBuffer.getExchangeDirectSerialisationType(); + if (outputBuffer.isSpoolingOutputBuffer() && serialisationType != FileSystemExchangeConfig.DirectSerialisationType.OFF) { + PagesSerde directSerde = (serialisationType == FileSystemExchangeConfig.DirectSerialisationType.JAVA) ? 
operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); + List pages = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + outputBuffer.enqueuePages(partition, pages, id, directSerde); + } + else { + List serializedPages = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES).stream() + .map(page -> operatorContext.getDriverContext().getSerde().serialize(page)) + .collect(toImmutableList()); + + outputBuffer.enqueue(partition, serializedPages, id); + } + pagesAdded.incrementAndGet(); + rowsAdded.addAndGet(pagePartition.getPositionCount()); + } + } + } + + @Override + public Object capture(BlockEncodingSerdeProvider serdeProvider) + { + PagePartitionerState myState = new PagePartitionerState(); + // This was just flushed, so page builders must be empty + for (int i = 0; i < pageBuilders.length; i++) { + checkState(pageBuilders[i].isEmpty()); + } + myState.rowsAdded = rowsAdded.get(); + myState.pagesAdded = pagesAdded.get(); + myState.hasAnyRowBeenReplicated = hasAnyRowBeenReplicated; + return myState; + } + + @Override + public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) + { + PagePartitionerState myState = (PagePartitionerState) state; + this.rowsAdded.set(myState.rowsAdded); + this.pagesAdded.set(myState.pagesAdded); + this.hasAnyRowBeenReplicated = myState.hasAnyRowBeenReplicated; + } + + private static class PagePartitionerState + implements Serializable + { + private long rowsAdded; + private long pagesAdded; + private boolean hasAnyRowBeenReplicated; + } +} diff --git a/presto-main/src/main/java/io/prestosql/operator/PartitionedOutputOperator.java b/presto-main/src/main/java/io/prestosql/operator/output/PartitionedOutputOperator.java similarity index 64% rename from presto-main/src/main/java/io/prestosql/operator/PartitionedOutputOperator.java rename to presto-main/src/main/java/io/prestosql/operator/output/PartitionedOutputOperator.java index bcc9552c2..51e886fd5 100644 --- a/presto-main/src/main/java/io/prestosql/operator/PartitionedOutputOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/output/PartitionedOutputOperator.java @@ -11,27 +11,30 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package io.prestosql.operator; +package io.prestosql.operator.output; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; -import io.hetu.core.transport.execution.buffer.PagesSerde; import io.hetu.core.transport.execution.buffer.SerializedPage; -import io.prestosql.exchange.FileSystemExchangeConfig.DirectSerialisationType; import io.prestosql.execution.buffer.OutputBuffer; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.operator.DriverContext; +import io.prestosql.operator.Operator; +import io.prestosql.operator.OperatorContext; +import io.prestosql.operator.OperatorFactory; +import io.prestosql.operator.OperatorInfo; +import io.prestosql.operator.OutputFactory; +import io.prestosql.operator.PartitionFunction; +import io.prestosql.operator.SinkOperator; +import io.prestosql.operator.TaskContext; import io.prestosql.snapshot.SingleInputSnapshotState; import io.prestosql.spi.Page; -import io.prestosql.spi.PageBuilder; -import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.RunLengthEncodedBlock; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.predicate.NullableValue; import io.prestosql.spi.snapshot.BlockEncodingSerdeProvider; import io.prestosql.spi.snapshot.MarkerPage; -import io.prestosql.spi.snapshot.Restorable; import io.prestosql.spi.snapshot.RestorableConfig; import io.prestosql.spi.type.Type; import io.prestosql.util.Mergeable; @@ -42,16 +45,10 @@ import java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.prestosql.execution.buffer.PageSplitterUtil.splitPage; -import static io.prestosql.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; -import static java.lang.Math.max; -import static java.lang.Math.min; import static java.util.Objects.requireNonNull; @RestorableConfig(uncapturedFields = {"pagePreprocessor", "snapshotState"}) @@ -68,6 +65,7 @@ public static class PartitionedOutputFactory private final boolean replicatesAnyRow; private final OptionalInt nullChannel; private final DataSize maxMemory; + private final PositionsAppenderFactory positionsAppenderFactory; public PartitionedOutputFactory( PartitionFunction partitionFunction, @@ -76,7 +74,8 @@ public PartitionedOutputFactory( boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, - DataSize maxMemory) + DataSize maxMemory, + PositionsAppenderFactory positionsAppenderFactory) { this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null"); @@ -85,6 +84,7 @@ public PartitionedOutputFactory( this.nullChannel = requireNonNull(nullChannel, "nullChannel is null"); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.maxMemory = requireNonNull(maxMemory, "maxMemory is null"); + this.positionsAppenderFactory = requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null"); } @Override @@ -107,7 +107,8 @@ public OperatorFactory createOutputOperator( replicatesAnyRow, nullChannel, outputBuffer, - 
maxMemory); + maxMemory, + positionsAppenderFactory); } } @@ -125,6 +126,7 @@ public static class PartitionedOutputOperatorFactory private final OptionalInt nullChannel; private final OutputBuffer outputBuffer; private final DataSize maxMemory; + private final PositionsAppenderFactory positionsAppenderFactory; // Snapshot: When a factory is duplicated, factory instances share the same OutputBuffer. // All these factory instances now share this duplicateCount, so only the last factory that receives "noMoreOperators" // (the one that decrements the count to 0) should inform OutputBuffer about "setNoMoreInputChannels". @@ -141,7 +143,8 @@ public PartitionedOutputOperatorFactory( boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, - DataSize maxMemory) + DataSize maxMemory, + PositionsAppenderFactory positionsAppenderFactory) { this( operatorId, @@ -155,7 +158,8 @@ public PartitionedOutputOperatorFactory( nullChannel, outputBuffer, maxMemory, - new AtomicInteger(1)); + new AtomicInteger(1), + positionsAppenderFactory); } private PartitionedOutputOperatorFactory( @@ -170,7 +174,8 @@ private PartitionedOutputOperatorFactory( OptionalInt nullChannel, OutputBuffer outputBuffer, DataSize maxMemory, - AtomicInteger duplicateCount) + AtomicInteger duplicateCount, + PositionsAppenderFactory positionsAppenderFactory) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); @@ -184,6 +189,7 @@ private PartitionedOutputOperatorFactory( this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.maxMemory = requireNonNull(maxMemory, "maxMemory is null"); this.duplicateCount = requireNonNull(duplicateCount, "duplicateCount is null"); + this.positionsAppenderFactory = requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null"); } @Override @@ -203,7 +209,8 @@ public Operator createOperator(DriverContext driverContext) replicatesAnyRow, nullChannel, outputBuffer, - maxMemory); + maxMemory, + positionsAppenderFactory); } @Override @@ -231,7 +238,8 @@ public OperatorFactory duplicate() nullChannel, outputBuffer, maxMemory, - duplicateCount); + duplicateCount, + positionsAppenderFactory); } } @@ -255,7 +263,8 @@ public PartitionedOutputOperator( boolean replicatesAnyRow, OptionalInt nullChannel, OutputBuffer outputBuffer, - DataSize maxMemory) + DataSize maxMemory, + PositionsAppenderFactory positionsAppenderFactory) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.pagePreprocessor = requireNonNull(pagePreprocessor, "pagePreprocessor is null"); @@ -269,7 +278,8 @@ public PartitionedOutputOperator( outputBuffer, operatorContext, sourceTypes, - maxMemory); + maxMemory, + positionsAppenderFactory); operatorContext.setInfoSupplier(this::getInfo); this.systemMemoryContext = operatorContext.newLocalSystemMemoryContext(PartitionedOutputOperator.class.getSimpleName()); @@ -366,200 +376,6 @@ public void addInput(Page page) systemMemoryContext.setBytes(partitionsSizeInBytes + partitionsInitialRetainedSize); } - @RestorableConfig(stateClassName = "PagePartitionerState", uncapturedFields = {"outputBuffer", "sourceTypes", "partitionFunction", "partitionChannels", - "partitionConstants", "operatorContext", "pageBuilders"}) - private static class PagePartitioner - implements Restorable - { - private final String id; - //shared field - private final OutputBuffer outputBuffer; - private final List sourceTypes; - private final PartitionFunction partitionFunction; - private final List 
partitionChannels; - private final List> partitionConstants; - private final PageBuilder[] pageBuilders; - private final boolean replicatesAnyRow; - private final OptionalInt nullChannel; // when present, send the position to every partition if this channel is null. - private final AtomicLong rowsAdded = new AtomicLong(); - private final AtomicLong pagesAdded = new AtomicLong(); - private boolean hasAnyRowBeenReplicated; - private final OperatorContext operatorContext; - - public PagePartitioner( - String id, - PartitionFunction partitionFunction, - List partitionChannels, - List> partitionConstants, - boolean replicatesAnyRow, - OptionalInt nullChannel, - OutputBuffer outputBuffer, - OperatorContext operatorContext, - List sourceTypes, - DataSize maxMemory) - { - this.id = id; - this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); - this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null"); - this.partitionConstants = requireNonNull(partitionConstants, "partitionConstants is null").stream() - .map(constant -> constant.map(NullableValue::asBlock)) - .collect(toImmutableList()); - this.replicatesAnyRow = replicatesAnyRow; - this.nullChannel = requireNonNull(nullChannel, "nullChannel is null"); - this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); - this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null"); - this.operatorContext = requireNonNull(operatorContext, "serde is null"); - - int partitionCount = partitionFunction.getPartitionCount(); - int pageSize = min(DEFAULT_MAX_PAGE_SIZE_IN_BYTES, ((int) maxMemory.toBytes()) / partitionCount); - pageSize = max(1, pageSize); - - this.pageBuilders = new PageBuilder[partitionCount]; - for (int i = 0; i < partitionCount; i++) { - pageBuilders[i] = PageBuilder.withMaxPageSize(pageSize, sourceTypes); - } - } - - public ListenableFuture isFull() - { - return outputBuffer.isFull(); - } - - public long getSizeInBytes() - { - // We use a foreach loop instead of streams - // as it has much better performance. - long sizeInBytes = 0; - for (PageBuilder pageBuilder : pageBuilders) { - sizeInBytes += pageBuilder.getSizeInBytes(); - } - return sizeInBytes; - } - - /** - * This method can be expensive for complex types. 
- */ - public long getRetainedSizeInBytes() - { - long sizeInBytes = 0; - for (PageBuilder pageBuilder : pageBuilders) { - sizeInBytes += pageBuilder.getRetainedSizeInBytes(); - } - return sizeInBytes; - } - - public PartitionedOutputInfo getInfo() - { - return new PartitionedOutputInfo(rowsAdded.get(), pagesAdded.get(), outputBuffer.getPeakMemoryUsage()); - } - - public void partitionPage(Page page) - { - requireNonNull(page, "page is null"); - - Page partitionFunctionArgs = getPartitionFunctionArguments(page); - for (int position = 0; position < page.getPositionCount(); position++) { - boolean shouldReplicate = (replicatesAnyRow && !hasAnyRowBeenReplicated) || - nullChannel.isPresent() && page.getBlock(nullChannel.getAsInt()).isNull(position); - if (shouldReplicate) { - for (PageBuilder pageBuilder : pageBuilders) { - appendRow(pageBuilder, page, position); - } - hasAnyRowBeenReplicated = true; - } - else { - int partition = partitionFunction.getPartition(partitionFunctionArgs, position); - appendRow(pageBuilders[partition], page, position); - } - } - flush(false); - } - - private Page getPartitionFunctionArguments(Page page) - { - Block[] blocks = new Block[partitionChannels.size()]; - for (int i = 0; i < blocks.length; i++) { - Optional partitionConstant = partitionConstants.get(i); - if (partitionConstant.isPresent()) { - blocks[i] = new RunLengthEncodedBlock(partitionConstant.get(), page.getPositionCount()); - } - else { - blocks[i] = page.getBlock(partitionChannels.get(i)); - } - } - return new Page(page.getPositionCount(), blocks); - } - - private void appendRow(PageBuilder pageBuilder, Page page, int position) - { - pageBuilder.declarePosition(); - - for (int channel = 0; channel < sourceTypes.size(); channel++) { - Type type = sourceTypes.get(channel); - type.appendTo(page.getBlock(channel), position, pageBuilder.getBlockBuilder(channel)); - } - } - - public void flush(boolean force) - { - // add all full pages to output buffer - for (int partition = 0; partition < pageBuilders.length; partition++) { - PageBuilder partitionPageBuilder = pageBuilders[partition]; - if (!partitionPageBuilder.isEmpty() && (force || partitionPageBuilder.isFull())) { - Page pagePartition = partitionPageBuilder.build(); - partitionPageBuilder.reset(); - - DirectSerialisationType serialisationType = outputBuffer.getExchangeDirectSerialisationType(); - if (outputBuffer.isSpoolingOutputBuffer() && serialisationType != DirectSerialisationType.OFF) { - PagesSerde directSerde = (serialisationType == DirectSerialisationType.JAVA) ? 
operatorContext.getDriverContext().getJavaSerde() : operatorContext.getDriverContext().getKryoSerde(); - List pages = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); - outputBuffer.enqueuePages(partition, pages, id, directSerde); - } - else { - List serializedPages = splitPage(pagePartition, DEFAULT_MAX_PAGE_SIZE_IN_BYTES).stream() - .map(page -> operatorContext.getDriverContext().getSerde().serialize(page)) - .collect(toImmutableList()); - - outputBuffer.enqueue(partition, serializedPages, id); - } - pagesAdded.incrementAndGet(); - rowsAdded.addAndGet(pagePartition.getPositionCount()); - } - } - } - - @Override - public Object capture(BlockEncodingSerdeProvider serdeProvider) - { - PagePartitionerState myState = new PagePartitionerState(); - // This was just flushed, so page builders must be empty - for (int i = 0; i < pageBuilders.length; i++) { - checkState(pageBuilders[i].isEmpty()); - } - myState.rowsAdded = rowsAdded.get(); - myState.pagesAdded = pagesAdded.get(); - myState.hasAnyRowBeenReplicated = hasAnyRowBeenReplicated; - return myState; - } - - @Override - public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) - { - PagePartitionerState myState = (PagePartitionerState) state; - this.rowsAdded.set(myState.rowsAdded); - this.pagesAdded.set(myState.pagesAdded); - this.hasAnyRowBeenReplicated = myState.hasAnyRowBeenReplicated; - } - - private static class PagePartitionerState - implements Serializable - { - private long rowsAdded; - private long pagesAdded; - private boolean hasAnyRowBeenReplicated; - } - } - public static class PartitionedOutputInfo implements Mergeable, OperatorInfo { diff --git a/presto-main/src/main/java/io/prestosql/operator/output/PositionsAppender.java b/presto-main/src/main/java/io/prestosql/operator/output/PositionsAppender.java new file mode 100644 index 000000000..1b04f6508 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/output/PositionsAppender.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.operator.output; + +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.type.Type; +import it.unimi.dsi.fastutil.ints.IntArrayList; + +import static java.util.Objects.requireNonNull; + +public interface PositionsAppender +{ + void appendTo(IntArrayList positions, Block source, BlockBuilder target); + + class TypedPositionsAppender + implements PositionsAppender + { + private final Type type; + + public TypedPositionsAppender(Type type) + { + this.type = requireNonNull(type, "type is null"); + } + + @Override + public void appendTo(IntArrayList positions, Block source, BlockBuilder target) + { + int[] positionArray = positions.elements(); + for (int i = 0; i < positions.size(); i++) { + type.appendTo(source, positionArray[i], target); + } + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/operator/output/PositionsAppenderFactory.java b/presto-main/src/main/java/io/prestosql/operator/output/PositionsAppenderFactory.java new file mode 100644 index 000000000..e6565b55e --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/output/PositionsAppenderFactory.java @@ -0,0 +1,347 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.output; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import io.airlift.bytecode.DynamicClassLoader; +import io.prestosql.operator.output.PositionsAppender.TypedPositionsAppender; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.Int128ArrayBlock; +import io.prestosql.spi.type.FixedWidthType; +import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.VariableWidthType; +import io.prestosql.sql.gen.IsolatedClass; +import it.unimi.dsi.fastutil.ints.IntArrayList; + +import java.util.Objects; +import java.util.Optional; + +import static io.airlift.slice.SizeOf.SIZE_OF_LONG; +import static java.util.Objects.requireNonNull; + +/** + * Isolates the {@code PositionsAppender} class per type and block tuples. + * Type specific {@code PositionsAppender} implementations manually inline {@code Type#appendTo} method inside the loop + * to avoid virtual(mega-morphic) calls and force jit to inline the {@code Block} and {@code BlockBuilder} methods. + * Ideally, {@code TypedPositionsAppender} could work instead of type specific {@code PositionsAppender}s, + * but in practice jit falls back to virtual calls in some cases (e.g. {@link Block#isNull}). 
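A minimal usage sketch of the new appender factory (not code from this change; it assumes the usual static import of BigintType.BIGINT plus io.prestosql.spi.block.Block/BlockBuilder and fastutil's IntArrayList):

    // Build a small BIGINT source block to copy from (values 10, 20, 30, 40, 50, 60).
    BlockBuilder sourceBuilder = BIGINT.createBlockBuilder(null, 6);
    for (long value : new long[] {10, 20, 30, 40, 50, 60}) {
        BIGINT.writeLong(sourceBuilder, value);
    }
    Block source = sourceBuilder.build();

    // Obtain an appender specialized for BIGINT and the concrete block class, then copy a subset of positions.
    PositionsAppenderFactory factory = new PositionsAppenderFactory();
    PositionsAppender appender = factory.create(BIGINT, source.getClass());

    IntArrayList positions = new IntArrayList(new int[] {0, 2, 5}); // positions routed to one partition
    BlockBuilder target = BIGINT.createBlockBuilder(null, positions.size());
    appender.appendTo(positions, source, target);
    Block copied = target.build(); // holds 10, 30, 60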
+ */ +public class PositionsAppenderFactory +{ + private final LoadingCache cache; + + public PositionsAppenderFactory() + { + this.cache = CacheBuilder.newBuilder().maximumSize(1000).build(CacheLoader.from(key -> createAppender(key.type))); + } + + public PositionsAppender create(Type type, Class blockClass) + { + return cache.getUnchecked(new CacheKey(type, blockClass)); + } + + private PositionsAppender createAppender(Type type) + { + return Optional.ofNullable(findDedicatedAppenderClassFor(type)) + .map(this::isolateAppender) + .orElseGet(() -> isolateTypeAppender(type)); + } + + private Class findDedicatedAppenderClassFor(Type type) + { + if (type instanceof FixedWidthType) { + switch (((FixedWidthType) type).getFixedSize()) { + case Byte.BYTES: + return BytePositionsAppender.class; + case Short.BYTES: + return SmallintPositionsAppender.class; + case Integer.BYTES: + return IntPositionsAppender.class; + case Long.BYTES: + return LongPositionsAppender.class; + case Int128ArrayBlock.INT128_BYTES: + return Int128PositionsAppender.class; + default: + // size not supported directly, fallback to the generic appender + } + } + else if (type instanceof VariableWidthType) { + return SlicePositionsAppender.class; + } + + return null; + } + + private PositionsAppender isolateTypeAppender(Type type) + { + Class isolatedAppenderClass = isolateAppenderClass(TypedPositionsAppender.class); + try { + return isolatedAppenderClass.getConstructor(Type.class).newInstance(type); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + + private PositionsAppender isolateAppender(Class appenderClass) + { + Class isolatedAppenderClass = isolateAppenderClass(appenderClass); + try { + return isolatedAppenderClass.getConstructor().newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + + private Class isolateAppenderClass(Class appenderClass) + { + DynamicClassLoader dynamicClassLoader = new DynamicClassLoader(PositionsAppender.class.getClassLoader()); + + Class isolatedBatchPositionsTransferClass = IsolatedClass.isolateClass( + dynamicClassLoader, + PositionsAppender.class, + appenderClass); + return isolatedBatchPositionsTransferClass; + } + + public static class LongPositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeLong(block.getLong(position, 0)).closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + blockBuilder.writeLong(block.getLong(positionArray[i], 0)).closeEntry(); + } + } + } + } + + public static class IntPositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeInt(block.getInt(position, 0)).closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + blockBuilder.writeInt(block.getInt(positionArray[i], 0)).closeEntry(); + } + } + } + } + + public static class 
BytePositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeByte(block.getByte(position, 0)).closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + blockBuilder.writeByte(block.getByte(positionArray[i], 0)).closeEntry(); + } + } + } + } + + public static class SlicePositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + block.writeBytesTo(position, 0, block.getSliceLength(position), blockBuilder); + blockBuilder.closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + block.writeBytesTo(position, 0, block.getSliceLength(position), blockBuilder); + blockBuilder.closeEntry(); + } + } + } + } + + public static class SmallintPositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeShort(block.getShort(position, 0)).closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + blockBuilder.writeShort(block.getShort(positionArray[i], 0)).closeEntry(); + } + } + } + } + + public static class Int96PositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeLong(block.getLong(position, 0)); + blockBuilder.writeInt(block.getInt(position, SIZE_OF_LONG)); + blockBuilder.closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + blockBuilder.writeLong(block.getLong(position, 0)); + blockBuilder.writeInt(block.getInt(position, SIZE_OF_LONG)); + blockBuilder.closeEntry(); + } + } + } + } + + public static class Int128PositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeLong(block.getLong(position, 0)); + blockBuilder.writeLong(block.getLong(position, SIZE_OF_LONG)); + blockBuilder.closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + 
blockBuilder.writeLong(block.getLong(position, 0)); + blockBuilder.writeLong(block.getLong(position, SIZE_OF_LONG)); + blockBuilder.closeEntry(); + } + } + } + } + + private static class CacheKey + { + private final Type type; + private final Class blockClass; + + private CacheKey(Type type, Class blockClass) + { + this.type = requireNonNull(type, "type is null"); + this.blockClass = requireNonNull(blockClass, "blockClass is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CacheKey cacheKey = (CacheKey) o; + return type.equals(cacheKey.type) && blockClass.equals(cacheKey.blockClass); + } + + @Override + public int hashCode() + { + return Objects.hash(type, blockClass); + } + } +} diff --git a/presto-main/src/main/java/io/prestosql/query/CachedSqlQueryExecution.java b/presto-main/src/main/java/io/prestosql/query/CachedSqlQueryExecution.java index 10543b226..7c4b2040c 100644 --- a/presto-main/src/main/java/io/prestosql/query/CachedSqlQueryExecution.java +++ b/presto-main/src/main/java/io/prestosql/query/CachedSqlQueryExecution.java @@ -33,7 +33,6 @@ import io.prestosql.execution.SqlQueryExecution; import io.prestosql.execution.SqlTaskManager; import io.prestosql.execution.TableExecuteContextManager; -import io.prestosql.execution.scheduler.ExecutionPolicy; import io.prestosql.execution.scheduler.NodeAllocatorService; import io.prestosql.execution.scheduler.NodeScheduler; import io.prestosql.execution.scheduler.PartitionMemoryEstimatorFactory; @@ -41,6 +40,7 @@ import io.prestosql.execution.scheduler.TaskDescriptorStorage; import io.prestosql.execution.scheduler.TaskExecutionStats; import io.prestosql.execution.scheduler.TaskSourceFactory; +import io.prestosql.execution.scheduler.policy.ExecutionPolicy; import io.prestosql.execution.warnings.WarningCollector; import io.prestosql.failuredetector.FailureDetector; import io.prestosql.heuristicindex.HeuristicIndexerManager; diff --git a/presto-main/src/main/java/io/prestosql/server/CoordinatorModule.java b/presto-main/src/main/java/io/prestosql/server/CoordinatorModule.java index a1a84ec30..fc8580ada 100644 --- a/presto-main/src/main/java/io/prestosql/server/CoordinatorModule.java +++ b/presto-main/src/main/java/io/prestosql/server/CoordinatorModule.java @@ -99,20 +99,21 @@ import io.prestosql.execution.resourcegroups.InternalResourceGroupManager; import io.prestosql.execution.resourcegroups.LegacyResourceGroupConfigurationManager; import io.prestosql.execution.resourcegroups.ResourceGroupManager; -import io.prestosql.execution.scheduler.AllAtOnceExecutionPolicy; import io.prestosql.execution.scheduler.BinPackingNodeAllocatorService; import io.prestosql.execution.scheduler.ConstantPartitionMemoryEstimator; -import io.prestosql.execution.scheduler.ExecutionPolicy; import io.prestosql.execution.scheduler.FixedCountNodeAllocatorService; import io.prestosql.execution.scheduler.NodeAllocatorService; import io.prestosql.execution.scheduler.NodeSchedulerConfig; import io.prestosql.execution.scheduler.PartitionMemoryEstimatorFactory; -import io.prestosql.execution.scheduler.PhasedExecutionPolicy; import io.prestosql.execution.scheduler.SplitSchedulerStats; import io.prestosql.execution.scheduler.StageTaskSourceFactory; import io.prestosql.execution.scheduler.TaskDescriptorStorage; import io.prestosql.execution.scheduler.TaskExecutionStats; import io.prestosql.execution.scheduler.TaskSourceFactory; +import 
io.prestosql.execution.scheduler.policy.AllAtOnceExecutionPolicy; +import io.prestosql.execution.scheduler.policy.ExecutionPolicy; +import io.prestosql.execution.scheduler.policy.PhasedExecutionPolicy; +import io.prestosql.execution.scheduler.policy.PrioritizeUtilizationExecutionPolicy; import io.prestosql.failuredetector.CoordinatorGossipFailureDetectorModule; import io.prestosql.failuredetector.FailureDetectorModule; import io.prestosql.memory.ClusterMemoryManager; @@ -451,6 +452,7 @@ else if (nodeSchedulerConfig.getNodeAllocatorType() == FIXED_COUNT) { MapBinder executionPolicyBinder = newMapBinder(binder, String.class, ExecutionPolicy.class); executionPolicyBinder.addBinding("all-at-once").to(AllAtOnceExecutionPolicy.class); executionPolicyBinder.addBinding("phased").to(PhasedExecutionPolicy.class); + executionPolicyBinder.addBinding("prioritize-utilization").to(PrioritizeUtilizationExecutionPolicy.class); binder.bind(TaskSourceFactory.class).to(StageTaskSourceFactory.class).in(Scopes.SINGLETON); binder.bind(TaskDescriptorStorage.class).in(Scopes.SINGLETON); diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java index 0acdf6ff8..3f9c03916 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/FeaturesConfig.java @@ -30,6 +30,7 @@ import javax.validation.constraints.AssertTrue; import javax.validation.constraints.DecimalMax; import javax.validation.constraints.DecimalMin; +import javax.validation.constraints.Max; import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; @@ -184,6 +185,13 @@ public class FeaturesConfig private boolean skipNonApplicableRulesEnabled; private boolean prioritizeLargerSpiltsMemoryRevoke = true; private DataSize revocableMemorySelectionThreshold = new DataSize(512, MEGABYTE); + private double joinMultiClauseIndependenceFactor = 0.25; + private double filterConjunctionIndependenceFactor = 0.75; + + // adaptive partial aggregation + private boolean adaptivePartialAggregationEnabled = true; + private long adaptivePartialAggregationMinRows = 100_000; + private double adaptivePartialAggregationUniqueRowsRatioThreshold = 0.8; public enum JoinReorderingStrategy { @@ -1513,4 +1521,72 @@ public FeaturesConfig setRevocableMemorySelectionThreshold(DataSize revocableMem this.revocableMemorySelectionThreshold = revocableMemorySelectionThreshold; return this; } + + @Min(0) + @Max(1) + public double getJoinMultiClauseIndependenceFactor() + { + return joinMultiClauseIndependenceFactor; + } + + @Config("optimizer.join-multi-clause-independence-factor") + @ConfigDescription("Scales the strength of independence assumption for selectivity estimates of multi-clause joins") + public FeaturesConfig setJoinMultiClauseIndependenceFactor(double joinMultiClauseIndependenceFactor) + { + this.joinMultiClauseIndependenceFactor = joinMultiClauseIndependenceFactor; + return this; + } + + @Min(0) + @Max(1) + public double getFilterConjunctionIndependenceFactor() + { + return filterConjunctionIndependenceFactor; + } + + @Config("optimizer.filter-conjunction-independence-factor") + @ConfigDescription("Scales the strength of independence assumption for selectivity estimates of the conjunction of multiple filters") + public FeaturesConfig setFilterConjunctionIndependenceFactor(double filterConjunctionIndependenceFactor) + { + this.filterConjunctionIndependenceFactor = 
filterConjunctionIndependenceFactor; + return this; + } + + public boolean isAdaptivePartialAggregationEnabled() + { + return adaptivePartialAggregationEnabled; + } + + @Config("adaptive-partial-aggregation.enabled") + public FeaturesConfig setAdaptivePartialAggregationEnabled(boolean adaptivePartialAggregationEnabled) + { + this.adaptivePartialAggregationEnabled = adaptivePartialAggregationEnabled; + return this; + } + + public long getAdaptivePartialAggregationMinRows() + { + return adaptivePartialAggregationMinRows; + } + + @Config("adaptive-partial-aggregation.min-rows") + @ConfigDescription("Minimum number of processed rows before partial aggregation might be adaptively turned off") + public FeaturesConfig setAdaptivePartialAggregationMinRows(long adaptivePartialAggregationMinRows) + { + this.adaptivePartialAggregationMinRows = adaptivePartialAggregationMinRows; + return this; + } + + public double getAdaptivePartialAggregationUniqueRowsRatioThreshold() + { + return adaptivePartialAggregationUniqueRowsRatioThreshold; + } + + @Config("adaptive-partial-aggregation.unique-rows-ratio-threshold") + @ConfigDescription("Ratio between aggregation output and input rows above which partial aggregation might be adaptively turned off") + public FeaturesConfig setAdaptivePartialAggregationUniqueRowsRatioThreshold(double adaptivePartialAggregationUniqueRowsRatioThreshold) + { + this.adaptivePartialAggregationUniqueRowsRatioThreshold = adaptivePartialAggregationUniqueRowsRatioThreshold; + return this; + } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java index 64bf99f82..64b480e3f 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/LocalExecutionPlanner.java @@ -83,7 +83,6 @@ import io.prestosql.operator.PagesSpatialIndexFactory; import io.prestosql.operator.PartitionFunction; import io.prestosql.operator.PartitionedLookupSourceFactory; -import io.prestosql.operator.PartitionedOutputOperator.PartitionedOutputFactory; import io.prestosql.operator.PipelineExecutionStrategy; import io.prestosql.operator.RowNumberOperator; import io.prestosql.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory; @@ -114,6 +113,7 @@ import io.prestosql.operator.aggregation.AccumulatorFactory; import io.prestosql.operator.aggregation.InternalAggregationFunction; import io.prestosql.operator.aggregation.LambdaProvider; +import io.prestosql.operator.aggregation.partial.PartialAggregationController; import io.prestosql.operator.dynamicfilter.CrossRegionDynamicFilterOperator; import io.prestosql.operator.exchange.LocalExchange.LocalExchangeFactory; import io.prestosql.operator.exchange.LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory; @@ -126,6 +126,8 @@ import io.prestosql.operator.index.IndexJoinLookupStats; import io.prestosql.operator.index.IndexLookupSourceFactory; import io.prestosql.operator.index.IndexSourceOperator; +import io.prestosql.operator.output.PartitionedOutputOperator.PartitionedOutputFactory; +import io.prestosql.operator.output.PositionsAppenderFactory; import io.prestosql.operator.project.CursorProcessor; import io.prestosql.operator.project.PageProcessor; import io.prestosql.operator.window.FrameInfo; @@ -261,6 +263,8 @@ import static com.google.common.collect.Range.closedOpen; import static io.airlift.concurrent.MoreFutures.addSuccessCallback; import static 
io.airlift.units.DataSize.Unit.BYTE; +import static io.prestosql.SystemSessionProperties.getAdaptivePartialAggregationMinRows; +import static io.prestosql.SystemSessionProperties.getAdaptivePartialAggregationUniqueRowsRatioThreshold; import static io.prestosql.SystemSessionProperties.getAggregationOperatorUnspillMemoryLimit; import static io.prestosql.SystemSessionProperties.getCteMaxPrefetchQueueSize; import static io.prestosql.SystemSessionProperties.getCteMaxQueueSize; @@ -272,6 +276,7 @@ import static io.prestosql.SystemSessionProperties.getSpillOperatorThresholdReuseExchange; import static io.prestosql.SystemSessionProperties.getTaskConcurrency; import static io.prestosql.SystemSessionProperties.getTaskWriterCount; +import static io.prestosql.SystemSessionProperties.isAdaptivePartialAggregationEnabled; import static io.prestosql.SystemSessionProperties.isCTEReuseEnabled; import static io.prestosql.SystemSessionProperties.isCrossRegionDynamicFilterEnabled; import static io.prestosql.SystemSessionProperties.isEnableDynamicFiltering; @@ -381,6 +386,7 @@ public class LocalExecutionPlanner protected final TaskManagerConfig taskManagerConfig; private final ExchangeManagerRegistry exchangeManagerRegistry; protected final TableExecuteContextManager tableExecuteContextManager; + private final PositionsAppenderFactory positionsAppenderFactory = new PositionsAppenderFactory(); public Metadata getMetadata() { @@ -682,7 +688,8 @@ public LocalExecutionPlan plan( partitioningScheme.isReplicateNullsAndAny(), nullChannel, outputBuffer, - maxPagePartitioningBufferSize), + maxPagePartitioningBufferSize, + positionsAppenderFactory), feederCTEId, feederCTEParentId, cteCtx); @@ -3631,7 +3638,8 @@ private OperatorFactory createHashAggregationOperatorFactory( unspillMemoryLimit, spillerFactory, joinCompiler, - useSystemMemory); + useSystemMemory, + createPartialAggregationController(step, session)); } } @@ -3683,7 +3691,8 @@ private OperatorFactory createSortAggregationOperatorFactory( spillerFactory, joinCompiler, useSystemMemory, - finalizeSymbol.isPresent() ? true : false); + finalizeSymbol.isPresent() ? true : false, + createPartialAggregationController(step, session)); } private Optional getOutputMappingAndGroupIdChannel(Map aggregations, @@ -3735,6 +3744,15 @@ private Optional getOutputMappingAndGroupIdChannel(Map createPartialAggregationController(AggregationNode.Step step, Session session) + { + return step.isOutputPartial() && isAdaptivePartialAggregationEnabled(session) ? 
+ Optional.of(new PartialAggregationController( + getAdaptivePartialAggregationMinRows(session), + getAdaptivePartialAggregationUniqueRowsRatioThreshold(session))) : + Optional.empty(); + } + private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata) { WriterTarget target = node.getTarget(); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java index ea2ae77fe..c2ba93b2e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java @@ -336,11 +336,6 @@ private boolean hasHighSelectivity(PlanNode node) } } - // If no predicate at all, the selectivity will be high - if (node instanceof TableScanNode && ((TableScanNode) buildSideTableScanNode.get()).getEnforcedConstraint().isAll()) { - return true; - } - Estimate totalRowCount = metadata.getTableStatistics(session, ((TableScanNode) buildSideTableScanNode.get()).getTable(), Constraint.alwaysTrue(), true).getRowCount(); PlanNodeStatsEstimate filteredStats = statsProvider.getStats(node); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PlanNodeSearcher.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PlanNodeSearcher.java index ba15514db..3b668f0ed 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PlanNodeSearcher.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PlanNodeSearcher.java @@ -23,10 +23,12 @@ import java.util.function.Predicate; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Predicates.alwaysFalse; import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static io.prestosql.sql.planner.iterative.Lookup.noLookup; +import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; public class PlanNodeSearcher @@ -57,6 +59,21 @@ private PlanNodeSearcher(PlanNode node, Lookup lookup) this.lookup = requireNonNull(lookup, "lookup is null"); } + @SafeVarargs + public final PlanNodeSearcher whereIsInstanceOfAny(Class... 
classes) + { + return whereIsInstanceOfAny(asList(classes)); + } + + public final PlanNodeSearcher whereIsInstanceOfAny(List> classes) + { + Predicate predicate = alwaysFalse(); + for (Class clazz : classes) { + predicate = predicate.or(clazz::isInstance); + } + return where(predicate); + } + public PlanNodeSearcher where(Predicate where) { this.where = requireNonNull(where, "where is null"); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneCTENodes.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneCTENodes.java index fc3b4a6a7..8ef3ae640 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneCTENodes.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/PruneCTENodes.java @@ -106,13 +106,16 @@ private OptimizedPlanRewriter(Metadata metadata, TypeAnalyzer typeAnalyzer, Bool @Override public PlanNode visitJoin(JoinNode node, RewriteContext context) { + Integer left = getChildCTERefNum(node.getLeft()); + Integer right = getChildCTERefNum(node.getRight()); if (pruneCTEWithCrossJoin && node.isCrossJoin()) { - Integer left = getChildCTERefNum(node.getLeft()); - Integer right = getChildCTERefNum(node.getRight()); if (left != null && right != null && left.equals(right)) { cTEWithCrossJoinList.add(left); } } + if (left != null && right != null && left.equals(right)) { + cteToPrune.add(left); + } return context.defaultRewrite(node, context.get()); } @@ -160,7 +163,7 @@ public PlanNode visitCTEScan(CTEScanNode inputNode, RewriteContext c cteUsageMap.merge(commonCTERefNum, 1, Integer::sum); } else { - if (cteUsageMap.get(commonCTERefNum) == 1 || cteToPrune.contains(commonCTERefNum)) { + if (cteUsageMap.get(commonCTERefNum) == 1 || (cteToPrune.contains(commonCTERefNum) && cteUsageMap.get(commonCTERefNum) > 3)) { node = (CTEScanNode) visitPlan(node, context); return node.getSource(); } diff --git a/presto-main/src/main/java/io/prestosql/util/MoreMath.java b/presto-main/src/main/java/io/prestosql/util/MoreMath.java index 00365d868..52837b4f5 100644 --- a/presto-main/src/main/java/io/prestosql/util/MoreMath.java +++ b/presto-main/src/main/java/io/prestosql/util/MoreMath.java @@ -15,6 +15,7 @@ import java.util.stream.DoubleStream; +import static java.lang.Double.NaN; import static java.lang.Double.isNaN; public final class MoreMath @@ -110,4 +111,37 @@ public static double firstNonNaN(double... 
values) } throw new IllegalArgumentException("All values are NaN"); } + + public static double averageExcludingNaNs(double first, double second) + { + if (isNaN(first) && isNaN(second)) { + return NaN; + } + if (!isNaN(first) && !isNaN(second)) { + return (first + second) / 2; + } + return firstNonNaN(first, second); + } + + public static double minExcludeNaN(double v1, double v2) + { + if (isNaN(v1)) { + return v2; + } + if (isNaN(v2)) { + return v1; + } + return min(v1, v2); + } + + public static double maxExcludeNaN(double v1, double v2) + { + if (isNaN(v1)) { + return v2; + } + if (isNaN(v2)) { + return v1; + } + return max(v1, v2); + } } diff --git a/presto-main/src/test/java/io/prestosql/execution/scheduler/TestNodeScheduler.java b/presto-main/src/test/java/io/prestosql/execution/scheduler/TestNodeScheduler.java index 62eb67980..3257911d7 100644 --- a/presto-main/src/test/java/io/prestosql/execution/scheduler/TestNodeScheduler.java +++ b/presto-main/src/test/java/io/prestosql/execution/scheduler/TestNodeScheduler.java @@ -107,8 +107,8 @@ import static io.prestosql.SessionTestUtils.TEST_SESSION_REUSE; import static io.prestosql.execution.SqlStageExecution.createSqlStageExecution; import static io.prestosql.execution.scheduler.NetworkLocation.ROOT_LOCATION; -import static io.prestosql.execution.scheduler.TestPhasedExecutionSchedule.createTableScanPlanFragment; import static io.prestosql.execution.scheduler.TestSourcePartitionedScheduler.createFixedSplitSource; +import static io.prestosql.execution.scheduler.policy.TestPhasedExecutionSchedule.createTableScanPlanFragment; import static io.prestosql.spi.StandardErrorCode.NO_NODES_AVAILABLE; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.testing.TestingRecoveryUtils.NOOP_RECOVERY_UTILS; diff --git a/presto-main/src/test/java/io/prestosql/execution/scheduler/TestSourcePartitionedScheduler.java b/presto-main/src/test/java/io/prestosql/execution/scheduler/TestSourcePartitionedScheduler.java index 9c03631d9..20e9e6699 100644 --- a/presto-main/src/test/java/io/prestosql/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/presto-main/src/test/java/io/prestosql/execution/scheduler/TestSourcePartitionedScheduler.java @@ -339,7 +339,8 @@ public void testNoNodes() Iterables.getOnlyElement(plan.getSplitSources().keySet()), Iterables.getOnlyElement(plan.getSplitSources().values()), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(CONNECTOR_ID, false, null), stage::getAllTasks), - 2, session, new HeuristicIndexerManager(new FileSystemClientManager(), new HetuMetaStoreManager()), new TableExecuteContextManager()); + 2, session, new HeuristicIndexerManager(new FileSystemClientManager(), new HetuMetaStoreManager()), new TableExecuteContextManager(), + new DynamicFilterService(new LocalStateStoreProvider(new SeedStoreManager(new FileSystemClientManager())))); scheduler.schedule(); }).hasErrorCode(NO_NODES_AVAILABLE); } @@ -477,7 +478,8 @@ private static StageScheduler getSourcePartitionedScheduler( SplitSource splitSource = Iterables.getOnlyElement(plan.getSplitSources().values()); SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(splitSource.getCatalogName(), false, null), stage::getAllTasks); return newSourcePartitionedSchedulerAsStageScheduler(stage, sourceNode, splitSource, - placementPolicy, splitBatchSize, session, new HeuristicIndexerManager(new FileSystemClientManager(), new HetuMetaStoreManager()), new TableExecuteContextManager()); + 
placementPolicy, splitBatchSize, session, new HeuristicIndexerManager(new FileSystemClientManager(), new HetuMetaStoreManager()), new TableExecuteContextManager(), + new DynamicFilterService(new LocalStateStoreProvider(new SeedStoreManager(new FileSystemClientManager())))); } private static StageExecutionPlan createPlan(ConnectorSplitSource splitSource) diff --git a/presto-main/src/test/java/io/prestosql/execution/scheduler/TestPhasedExecutionSchedule.java b/presto-main/src/test/java/io/prestosql/execution/scheduler/policy/TestPhasedExecutionSchedule.java similarity index 99% rename from presto-main/src/test/java/io/prestosql/execution/scheduler/TestPhasedExecutionSchedule.java rename to presto-main/src/test/java/io/prestosql/execution/scheduler/policy/TestPhasedExecutionSchedule.java index 7f408e6e5..7c746ef30 100644 --- a/presto-main/src/test/java/io/prestosql/execution/scheduler/TestPhasedExecutionSchedule.java +++ b/presto-main/src/test/java/io/prestosql/execution/scheduler/policy/TestPhasedExecutionSchedule.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.prestosql.execution.scheduler; +package io.prestosql.execution.scheduler.policy; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; diff --git a/presto-main/src/test/java/io/prestosql/operator/BenchmarkHashAndStreamingAggregationOperators.java b/presto-main/src/test/java/io/prestosql/operator/BenchmarkHashAndStreamingAggregationOperators.java index 47bcb3168..55c69f0de 100644 --- a/presto-main/src/test/java/io/prestosql/operator/BenchmarkHashAndStreamingAggregationOperators.java +++ b/presto-main/src/test/java/io/prestosql/operator/BenchmarkHashAndStreamingAggregationOperators.java @@ -180,7 +180,8 @@ private OperatorFactory createHashAggregationOperatorFactory(Optional h succinctBytes(Integer.MAX_VALUE), spillerFactory, joinCompiler, - false); + false, + Optional.empty()); } private static void repeatToStringBlock(String value, int count, BlockBuilder blockBuilder) diff --git a/presto-main/src/test/java/io/prestosql/operator/BenchmarkPartitionedOutputOperator.java b/presto-main/src/test/java/io/prestosql/operator/BenchmarkPartitionedOutputOperator.java index 18970b94b..d381e0041 100644 --- a/presto-main/src/test/java/io/prestosql/operator/BenchmarkPartitionedOutputOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/BenchmarkPartitionedOutputOperator.java @@ -19,8 +19,10 @@ import io.prestosql.execution.buffer.OutputBuffers; import io.prestosql.execution.buffer.PartitionedOutputBuffer; import io.prestosql.memory.context.SimpleLocalMemoryContext; -import io.prestosql.operator.PartitionedOutputOperator.PartitionedOutputFactory; import io.prestosql.operator.exchange.LocalPartitionGenerator; +import io.prestosql.operator.output.PartitionedOutputOperator; +import io.prestosql.operator.output.PartitionedOutputOperator.PartitionedOutputFactory; +import io.prestosql.operator.output.PositionsAppenderFactory; import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.block.BlockBuilder; @@ -75,6 +77,8 @@ @BenchmarkMode(Mode.AverageTime) public class BenchmarkPartitionedOutputOperator { + private static final PositionsAppenderFactory POSITIONS_APPENDER_FACTORY = new PositionsAppenderFactory(); + @Benchmark public void addPage(BenchmarkData data) { @@ -126,7 +130,8 @@ private PartitionedOutputOperator createPartitionedOutputOperator() false, 
OptionalInt.empty(), buffer, - new DataSize(1, GIGABYTE)); + new DataSize(1, GIGABYTE), + POSITIONS_APPENDER_FACTORY); TaskContext taskContext = createTaskContext(); return (PartitionedOutputOperator) operatorFactory .createOutputOperator(0, new PlanNodeId("plan-node-0"), TYPES, Function.identity(), taskContext) diff --git a/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java index 9c751c56b..8d2af0016 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java @@ -27,6 +27,7 @@ import io.prestosql.operator.aggregation.InternalAggregationFunction; import io.prestosql.operator.aggregation.builder.AggregationBuilder; import io.prestosql.operator.aggregation.builder.InMemoryHashAggregationBuilder; +import io.prestosql.operator.aggregation.partial.PartialAggregationController; import io.prestosql.spi.Page; import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.block.PageBuilderStatus; @@ -65,11 +66,14 @@ import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; import static io.airlift.testing.Assertions.assertGreaterThan; +import static io.airlift.units.DataSize.Unit.BYTE; import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.airlift.units.DataSize.succinctBytes; import static io.prestosql.RowPagesBuilder.rowPagesBuilder; import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.block.BlockAssertions.createLongsBlock; +import static io.prestosql.block.BlockAssertions.createRLEBlock; import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; import static io.prestosql.operator.GroupByHashYieldAssertion.GroupByHashYieldResult; import static io.prestosql.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys; @@ -81,6 +85,7 @@ import static io.prestosql.operator.OperatorAssertion.toPages; import static io.prestosql.operator.OperatorAssertion.toPagesCompareStateSimple; import static io.prestosql.spi.function.FunctionKind.AGGREGATE; +import static io.prestosql.spi.plan.AggregationNode.Step.PARTIAL; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.DoubleType.DOUBLE; @@ -109,6 +114,8 @@ public class TestHashAggregationOperator new Signature(QualifiedObjectName.valueOfDefaultFunction("sum"), AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); private static final InternalAggregationFunction COUNT = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation( new Signature(QualifiedObjectName.valueOfDefaultFunction("count"), AGGREGATE, BIGINT.getTypeSignature())); + private static final InternalAggregationFunction LONG_MIN = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation( + new Signature(QualifiedObjectName.valueOfDefaultFunction("min"), AGGREGATE, BIGINT.getTypeSignature())); private static final int MAX_BLOCK_SIZE_IN_BYTES = 64 * 1024; @@ -203,7 +210,8 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boole succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, joinCompiler, - false); + false, + Optional.empty()); DriverContext driverContext = 
createDriverContext(memoryLimitForMerge); @@ -265,7 +273,8 @@ public void testHashAggregationSnapshot() succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, joinCompiler, - false); + false, + Optional.empty()); DriverContext driverContext = createDriverContext(memoryLimitForMerge); @@ -488,7 +497,8 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, joinCompiler, - false); + false, + Optional.empty()); DriverContext driverContext = createDriverContext(memoryLimitForMerge); MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT, BIGINT, BIGINT, DOUBLE, VARCHAR, BIGINT, BIGINT) @@ -536,7 +546,8 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, joinCompiler, - false); + false, + Optional.empty()); Operator operator = operatorFactory.createOperator(driverContext); toPages(operator, input.iterator(), revokeMemoryWhenAddingPages); @@ -578,7 +589,8 @@ public void testMemoryLimit(boolean hashEnabled) 100_000, Optional.of(new DataSize(16, MEGABYTE)), joinCompiler, - false); + false, + Optional.empty()); toPages(operatorFactory, driverContext, input); } @@ -618,7 +630,8 @@ public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boo succinctBytes(memoryLimitForMergeWithMemory), spillerFactory, joinCompiler, - false); + false, + Optional.empty()); toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages); } @@ -640,7 +653,8 @@ public void testMemoryReservationYield(Type type) 1, Optional.of(new DataSize(16, MEGABYTE)), joinCompiler, - false); + false, + Optional.empty()); // get result with yield; pick a relatively small buffer for aggregator's memory usage GroupByHashYieldResult result; @@ -692,7 +706,8 @@ public void testHashBuilderResizeLimit(boolean hashEnabled) 100_000, Optional.of(new DataSize(16, MEGABYTE)), joinCompiler, - false); + false, + Optional.empty()); toPages(operatorFactory, driverContext, input); } @@ -726,7 +741,8 @@ public void testMultiSliceAggregationOutput(boolean hashEnabled) 100_000, Optional.of(new DataSize(16, MEGABYTE)), joinCompiler, - false); + false, + Optional.empty()); assertEquals(toPages(operatorFactory, createDriverContext(), input).size(), 2); } @@ -750,14 +766,15 @@ public void testMultiplePartialFlushes(boolean hashEnabled) ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(), - Step.PARTIAL, + PARTIAL, ImmutableList.of(LONG_SUM.bind(ImmutableList.of(0), Optional.empty())), rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, Optional.of(new DataSize(1, KILOBYTE)), joinCompiler, - true); + true, + Optional.empty()); DriverContext driverContext = createDriverContext(1024); @@ -843,7 +860,8 @@ public void testMergeWithMemorySpill() succinctBytes(Integer.MAX_VALUE), spillerFactory, joinCompiler, - false); + false, + Optional.empty()); DriverContext driverContext = createDriverContext(smallPagesSpillThresholdSize); @@ -898,7 +916,8 @@ public void testSpillerFailure() succinctBytes(Integer.MAX_VALUE), new FailingSpillerFactory(), joinCompiler, - false); + false, + Optional.empty()); try { toPages(operatorFactory, driverContext, input); @@ -939,7 +958,8 @@ private void testMemoryTracking(boolean useSystemMemory) 100_000, Optional.of(new DataSize(16, MEGABYTE)), joinCompiler, - useSystemMemory); + useSystemMemory, + Optional.empty()); DriverContext driverContext = 
createDriverContext(1024); @@ -963,6 +983,106 @@ private void testMemoryTracking(boolean useSystemMemory) assertEquals(driverContext.getMemoryUsage(), 0); } + @Test(dataProvider = "hashEnabled") + public void testAdaptivePartialAggregation() + { + List hashChannels = Ints.asList(0); + + PartialAggregationController partialAggregationController = new PartialAggregationController(5, 0.8); + HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(BIGINT), + hashChannels, + ImmutableList.of(), + PARTIAL, + ImmutableList.of(LONG_MIN.bind(ImmutableList.of(0), Optional.empty())), + Optional.empty(), + Optional.empty(), + 100, + Optional.of(new DataSize(1, BYTE)), // this setting makes operator to flush after each page + joinCompiler, + false, + // use 5 rows threshold to trigger adaptive partial aggregation after each page flush + Optional.of(partialAggregationController)); + + // at the start partial aggregation is enabled + assertFalse(partialAggregationController.isPartialAggregationDisabled()); + // First operator will trigger adaptive partial aggregation after the first page + List operator1Input = rowPagesBuilder(false, hashChannels, BIGINT) + .addBlocksPage(createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8, 8)) // first page will be hashed but the values are almost unique, so it will trigger adaptation + .addBlocksPage(createRLEBlock(1, 10)) // second page would be hashed to existing value 1. but if adaptive PA kicks in, the raw values will be passed on + .build(); + List operator1Expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage(createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8), createLongsBlock(0, 1, 2, 3, 4, 5, 6, 7, 8)) // the last position was aggregated + .addBlocksPage(createRLEBlock(1, 10), createRLEBlock(1, 10)) // we are expecting second page with raw values + .build(); + OperatorAssertion.assertOperatorEquals(operatorFactory, ImmutableList.of(BIGINT), createDriverContext(), operator1Input, operator1Expected); + + // the first operator flush disables partial aggregation + assertTrue(partialAggregationController.isPartialAggregationDisabled()); + // second operator using the same factory, reuses PartialAggregationControl, so it will only produce raw pages (partial aggregation is disabled at this point) + List operator2Input = rowPagesBuilder(false, hashChannels, BIGINT) + .addBlocksPage(createRLEBlock(1, 10)) + .addBlocksPage(createRLEBlock(2, 10)) + .build(); + List operator2Expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage(createRLEBlock(1, 10), createRLEBlock(1, 10)) + .addBlocksPage(createRLEBlock(2, 10), createRLEBlock(2, 10)) + .build(); + + OperatorAssertion.assertOperatorEquals(operatorFactory, ImmutableList.of(BIGINT), createDriverContext(), operator2Input, operator2Expected); + } + + @Test + public void testAdaptivePartialAggregationTriggeredOnlyOnFlush() + { + List hashChannels = Ints.asList(0); + + PartialAggregationController partialAggregationController = new PartialAggregationController(5, 0.8); + HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory( + 0, + new PlanNodeId("test"), + ImmutableList.of(BIGINT), + hashChannels, + ImmutableList.of(), + PARTIAL, + ImmutableList.of(LONG_MIN.bind(ImmutableList.of(0), Optional.empty())), + Optional.empty(), + Optional.empty(), + 10, + Optional.of(new DataSize(16, MEGABYTE)), // this setting makes operator to flush only after all pages + joinCompiler, + false, + // use 5 rows threshold to trigger adaptive 
partial aggregation after each page flush + Optional.of(partialAggregationController)); + + List operator1Input = rowPagesBuilder(false, hashChannels, BIGINT) + .addSequencePage(10, 0) // first page has unique values, so it would trigger adaptation, but it won't because flush is not called + .addBlocksPage(createRLEBlock(1, 2)) // second page will be hashed to existing value 1 + .build(); + // the total unique rows ratio for the first operator will be 10/12 so > 0.8 (adaptive partial aggregation uniqueRowsRatioThreshold) + List operator1Expected = rowPagesBuilder(BIGINT, BIGINT) + .addSequencePage(10, 0, 0) // we are expecting second page to be squashed with the first + .build(); + OperatorAssertion.assertOperatorEquals(operatorFactory, ImmutableList.of(BIGINT), createDriverContext(), operator1Input, operator1Expected); + + // the first operator flush disables partial aggregation + assertTrue(partialAggregationController.isPartialAggregationDisabled()); + + // second operator using the same factory, reuses PartialAggregationControl, so it will only produce raw pages (partial aggregation is disabled at this point) + List operator2Input = rowPagesBuilder(false, hashChannels, BIGINT) + .addBlocksPage(createRLEBlock(1, 10)) + .addBlocksPage(createRLEBlock(2, 10)) + .build(); + List operator2Expected = rowPagesBuilder(BIGINT, BIGINT) + .addBlocksPage(createRLEBlock(1, 10), createRLEBlock(1, 10)) + .addBlocksPage(createRLEBlock(2, 10), createRLEBlock(2, 10)) + .build(); + + OperatorAssertion.assertOperatorEquals(operatorFactory, ImmutableList.of(BIGINT), createDriverContext(), operator2Input, operator2Expected); + } + private DriverContext createDriverContext() { return createDriverContext(Integer.MAX_VALUE); diff --git a/presto-main/src/test/java/io/prestosql/operator/TestOperatorStats.java b/presto-main/src/test/java/io/prestosql/operator/TestOperatorStats.java index 4f22d96dd..93bf85d00 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestOperatorStats.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestOperatorStats.java @@ -16,7 +16,7 @@ import io.airlift.json.JsonCodec; import io.airlift.units.DataSize; import io.airlift.units.Duration; -import io.prestosql.operator.PartitionedOutputOperator.PartitionedOutputInfo; +import io.prestosql.operator.output.PartitionedOutputOperator.PartitionedOutputInfo; import io.prestosql.spi.plan.PlanNodeId; import org.testng.annotations.Test; diff --git a/presto-main/src/test/java/io/prestosql/operator/TestPartitionedOutputOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestPartitionedOutputOperator.java index 936f060d1..3ec4b60de 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestPartitionedOutputOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestPartitionedOutputOperator.java @@ -20,6 +20,8 @@ import io.prestosql.execution.TaskId; import io.prestosql.execution.buffer.PartitionedOutputBuffer; import io.prestosql.operator.exchange.LocalPartitionGenerator; +import io.prestosql.operator.output.PartitionedOutputOperator; +import io.prestosql.operator.output.PositionsAppenderFactory; import io.prestosql.snapshot.RecoveryUtils; import io.prestosql.spi.Page; import io.prestosql.spi.plan.PlanNodeId; @@ -154,6 +156,7 @@ private Map createExpectedMappingAfterFinish() private PartitionedOutputOperator createPartitionedOutputOperator(RecoveryUtils recoveryUtils, PartitionedOutputBuffer buffer) { + PositionsAppenderFactory positionsAppenderFactory = new PositionsAppenderFactory(); PartitionFunction
partitionFunction = new LocalPartitionGenerator(new InterpretedHashGenerator(ImmutableList.of(BIGINT), new int[] {0}), PARTITION_COUNT); PartitionedOutputOperator.PartitionedOutputFactory operatorFactory = new PartitionedOutputOperator.PartitionedOutputFactory( partitionFunction, @@ -162,7 +165,8 @@ private PartitionedOutputOperator createPartitionedOutputOperator(RecoveryUtils false, OptionalInt.empty(), buffer, - new DataSize(1, GIGABYTE)); + new DataSize(1, GIGABYTE), + positionsAppenderFactory); TaskContext taskContext = createTaskContext(recoveryUtils); return (PartitionedOutputOperator) operatorFactory .createOutputOperator(0, new PlanNodeId("plan-node-0"), TYPES, Function.identity(), taskContext) diff --git a/presto-main/src/test/java/io/prestosql/operator/TestPartitionedOutputOperatorFactory.java b/presto-main/src/test/java/io/prestosql/operator/TestPartitionedOutputOperatorFactory.java index d426a2ac1..33680f368 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestPartitionedOutputOperatorFactory.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestPartitionedOutputOperatorFactory.java @@ -15,7 +15,8 @@ import io.airlift.units.DataSize; import io.prestosql.execution.buffer.OutputBuffer; -import io.prestosql.operator.PartitionedOutputOperator.PartitionedOutputOperatorFactory; +import io.prestosql.operator.output.PartitionedOutputOperator.PartitionedOutputOperatorFactory; +import io.prestosql.operator.output.PositionsAppenderFactory; import io.prestosql.spi.plan.PlanNodeId; import org.testng.annotations.Test; @@ -32,6 +33,7 @@ public class TestPartitionedOutputOperatorFactory @Test public void testDuplicate() { + PositionsAppenderFactory positionsAppenderFactory = new PositionsAppenderFactory(); OutputBuffer outputBuffer = mock(OutputBuffer.class); PartitionedOutputOperatorFactory factory1 = new PartitionedOutputOperatorFactory( 1, @@ -44,7 +46,8 @@ public void testDuplicate() false, OptionalInt.empty(), outputBuffer, - DataSize.succinctBytes(1)); + DataSize.succinctBytes(1), + positionsAppenderFactory); OperatorFactory factory2 = factory1.duplicate(); OperatorFactory factory3 = factory1.duplicate(); OperatorFactory factory4 = factory2.duplicate(); diff --git a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java index 864909480..959b9871d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestFeaturesConfig.java @@ -159,7 +159,10 @@ public void testDefaults() .setSkipAttachingStatsWithPlan(true) .setSkipNonApplicableRulesEnabled(false) .setPrioritizeLargerSpiltsMemoryRevoke(true) - .setRevocableMemorySelectionThreshold(new DataSize(512, MEGABYTE))); + .setRevocableMemorySelectionThreshold(new DataSize(512, MEGABYTE)) + .setAdaptivePartialAggregationEnabled(true) + .setAdaptivePartialAggregationMinRows(100_000) + .setAdaptivePartialAggregationUniqueRowsRatioThreshold(0.8)); } @Test @@ -269,6 +272,9 @@ public void testExplicitPropertyMappings() .put("optimizer.skip-non-applicable-rules-enabled", "true") .put("experimental.prioritize-larger-spilts-memory-revoke", "false") .put("experimental.revocable-memory-selection-threshold", "500MB") + .put("adaptive-partial-aggregation.enabled", "false") + .put("adaptive-partial-aggregation.min-rows", "1") + .put("adaptive-partial-aggregation.unique-rows-ratio-threshold", "0.99") .build(); FeaturesConfig expected = new 
FeaturesConfig() @@ -375,7 +381,10 @@ public void testExplicitPropertyMappings() .setSkipAttachingStatsWithPlan(false) .setSkipNonApplicableRulesEnabled(true) .setPrioritizeLargerSpiltsMemoryRevoke(false) - .setRevocableMemorySelectionThreshold(new DataSize(500, MEGABYTE)); + .setRevocableMemorySelectionThreshold(new DataSize(500, MEGABYTE)) + .setAdaptivePartialAggregationEnabled(false) + .setAdaptivePartialAggregationMinRows(1) + .setAdaptivePartialAggregationUniqueRowsRatioThreshold(0.99); assertFullMapping(properties, expected); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java b/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java index 033279447..35186ff6b 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java @@ -59,7 +59,7 @@ public class OrcReader { - public static final int MAX_BATCH_SIZE = 1024; + public static final int MAX_BATCH_SIZE = 8196; public static final int INITIAL_BATCH_SIZE = 1; public static final int BATCH_SIZE_GROWTH_FACTOR = 2; private static final int EXPECTED_FOOTER_SIZE = 16 * 1024; diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestReadBloomFilter.java b/presto-orc/src/test/java/io/prestosql/orc/TestReadBloomFilter.java index d7405b593..ad18723a2 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestReadBloomFilter.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestReadBloomFilter.java @@ -91,7 +91,7 @@ private static void testType(Type type, List uniqueValues, T inBloomFilte // without predicate a normal block will be created try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, OrcPredicate.TRUE, type, MAX_BATCH_SIZE)) { - assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 1024); + assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 8196); } // predicate for specific value within the min/max range without bloom filter being enabled @@ -100,7 +100,7 @@ private static void testType(Type type, List uniqueValues, T inBloomFilte .build(); try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, noBloomFilterPredicate, type, MAX_BATCH_SIZE)) { - assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 1024); + assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 8196); } // predicate for specific value within the min/max range with bloom filter enabled, but a value not in the bloom filter @@ -120,7 +120,7 @@ private static void testType(Type type, List uniqueValues, T inBloomFilte .build(); try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, matchBloomFilterPredicate, type, MAX_BATCH_SIZE)) { - assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 1024); + assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 8196); } } }
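
Editor's note on the new PositionsAppender path introduced in this diff: PositionsAppenderFactory hands back an appender that is isolated per (type, block class) pair so the hot position-copy loop stays monomorphic. The sketch below only uses the API shown above (PositionsAppenderFactory.create and PositionsAppender.appendTo); the partitioner/PageBuilder wiring around it is assumed purely for illustration and is not the actual PartitionedOutputOperator code.

    import io.prestosql.operator.output.PositionsAppender;
    import io.prestosql.operator.output.PositionsAppenderFactory;
    import io.prestosql.spi.Page;
    import io.prestosql.spi.PageBuilder;
    import io.prestosql.spi.block.Block;
    import io.prestosql.spi.type.Type;
    import it.unimi.dsi.fastutil.ints.IntArrayList;

    import java.util.List;

    class AppenderUsageSketch
    {
        // Copies the selected positions of each input channel into the page builder
        // of a single partition. The factory caches and class-isolates the appender
        // per (type, block class), so repeated calls stay cheap and monomorphic.
        static void copyPositions(PositionsAppenderFactory factory, List<Type> types,
                Page page, IntArrayList partitionPositions, PageBuilder partitionPageBuilder)
        {
            partitionPageBuilder.declarePositions(partitionPositions.size());
            for (int channel = 0; channel < types.size(); channel++) {
                Block source = page.getBlock(channel);
                PositionsAppender appender = factory.create(types.get(channel), source.getClass());
                appender.appendTo(partitionPositions, source, partitionPageBuilder.getBlockBuilder(channel));
            }
        }
    }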
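
Editor's note on adaptive partial aggregation: the adaptive-partial-aggregation.* config properties, createPartialAggregationController in LocalExecutionPlanner, and the new TestHashAggregationOperator tests all rely on a single PartialAggregationController shared by every operator built from one partial-aggregation factory. The controller's implementation is not part of this diff, so the following is only a sketch of the behaviour the tests exercise: after a partial flush, the controller compares unique output rows to input rows and, once at least min-rows have been processed and the ratio exceeds the threshold, disables partial aggregation for all operators from that factory. The constructor and isPartialAggregationDisabled() match the diff; the onFlush callback name and the internal bookkeeping are assumptions.

    // Sketch only, not the committed implementation.
    public class PartialAggregationControllerSketch
    {
        private final long minNumberOfRowsProcessed;   // adaptive-partial-aggregation.min-rows
        private final double uniqueRowsRatioThreshold; // adaptive-partial-aggregation.unique-rows-ratio-threshold

        private volatile boolean partialAggregationDisabled;
        private long totalRowsProcessed;
        private long totalUniqueRowsProduced;

        public PartialAggregationControllerSketch(long minNumberOfRowsProcessed, double uniqueRowsRatioThreshold)
        {
            this.minNumberOfRowsProcessed = minNumberOfRowsProcessed;
            this.uniqueRowsRatioThreshold = uniqueRowsRatioThreshold;
        }

        public boolean isPartialAggregationDisabled()
        {
            return partialAggregationDisabled;
        }

        // Hypothetical callback invoked by the partial aggregation operator on each flush.
        public synchronized void onFlush(long rowsProcessed, long uniqueRowsProduced)
        {
            totalRowsProcessed += rowsProcessed;
            totalUniqueRowsProduced += uniqueRowsProduced;
            if (totalRowsProcessed >= minNumberOfRowsProcessed
                    && (double) totalUniqueRowsProduced / totalRowsProcessed > uniqueRowsRatioThreshold) {
                partialAggregationDisabled = true;
            }
        }
    }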
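
Editor's note on PlanNodeSearcher.whereIsInstanceOfAny: the new helper simply ORs together instanceof predicates, so callers avoid hand-written where(n -> n instanceof A || n instanceof B) chains. The usage sketch below assumes the pre-existing PlanNodeSearcher.searchFrom(...) factory and findAll() terminal (not shown in this diff) and uses JoinNode/TableScanNode only as example node types; the package locations of those classes are assumed from the surrounding imports.

    import io.prestosql.spi.plan.JoinNode;
    import io.prestosql.spi.plan.PlanNode;
    import io.prestosql.spi.plan.TableScanNode;
    import io.prestosql.sql.planner.optimizations.PlanNodeSearcher;

    import java.util.List;

    class SearcherUsageSketch
    {
        // Collects every join and table scan under the given root in one pass.
        static List<PlanNode> findJoinsAndScans(PlanNode root)
        {
            return PlanNodeSearcher.searchFrom(root)
                    .whereIsInstanceOfAny(JoinNode.class, TableScanNode.class)
                    .findAll();
        }
    }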