Add metrics for Accumulator and GroupByHash #24015
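This change introduces a shared AggregationMetrics class that tracks three operator-level metrics: accumulator update CPU time, group-by hash update CPU time, and the number of input rows processed with partial aggregation disabled. Each aggregation operator (AggregationOperator, HashAggregationOperator, StreamingAggregationOperator) owns one instance and threads it into its Aggregators; a new MeasuredGroupByHashWork wrapper times group-by hash work, and the partial-aggregation row counter previously kept inline in HashAggregationOperator moves into the shared class.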

Open · wants to merge 2 commits into base: master
io/trino/operator/AggregationMetrics.java (new file)
@@ -0,0 +1,58 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import io.airlift.units.Duration;
import io.trino.plugin.base.metrics.DurationTiming;
import io.trino.plugin.base.metrics.LongCount;
import io.trino.spi.metrics.Metrics;

import static java.util.concurrent.TimeUnit.NANOSECONDS;

public class AggregationMetrics
{
    @VisibleForTesting
    static final String INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME = "Input rows processed without partial aggregation enabled";
    private static final String ACCUMULATOR_TIME_METRIC_NAME = "Accumulator update CPU time";
    private static final String GROUP_BY_HASH_TIME_METRIC_NAME = "Group by hash update CPU time";

    private long accumulatorTimeNanos;
    private long groupByHashTimeNanos;
    private long inputRowsProcessedWithPartialAggregationDisabled;

    public void recordAccumulatorUpdateTimeSince(long startNanos)
    {
        accumulatorTimeNanos += System.nanoTime() - startNanos;
    }

    public void recordGroupByHashUpdateTimeSince(long startNanos)
    {
        groupByHashTimeNanos += System.nanoTime() - startNanos;
    }

    public void recordInputRowsProcessedWithPartialAggregationDisabled(long rows)
    {
        inputRowsProcessedWithPartialAggregationDisabled += rows;
    }

    public Metrics getMetrics()
    {
        return new Metrics(ImmutableMap.of(
                INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME, new LongCount(inputRowsProcessedWithPartialAggregationDisabled),
                ACCUMULATOR_TIME_METRIC_NAME, new DurationTiming(new Duration(accumulatorTimeNanos, NANOSECONDS)),
                GROUP_BY_HASH_TIME_METRIC_NAME, new DurationTiming(new Duration(groupByHashTimeNanos, NANOSECONDS))));
    }
}
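The class is a plain mutable set of counters: callers capture System.nanoTime() before a hot call and record the elapsed nanos afterwards, as Aggregator and MeasuredGroupByHashWork do later in this diff, and getMetrics() snapshots the totals into the immutable map that the operators hand to OperatorContext.setLatestMetrics. A minimal usage sketch (the main harness and doWork workload are illustrative, not part of the PR):

import io.trino.operator.AggregationMetrics;
import io.trino.spi.metrics.Metrics;

public class AggregationMetricsSketch
{
    public static void main(String[] args)
    {
        AggregationMetrics metrics = new AggregationMetrics();

        // Same pattern as Aggregator#processPage below: snapshot nanoTime
        // before the update and record the delta afterwards.
        long start = System.nanoTime();
        doWork(); // hypothetical stand-in for Accumulator#addInput
        metrics.recordAccumulatorUpdateTimeSince(start);

        metrics.recordInputRowsProcessedWithPartialAggregationDisabled(1024);

        // Snapshot the running totals; keys are the metric names above.
        Metrics snapshot = metrics.getMetrics();
        System.out.println(snapshot.getMetrics());
    }

    private static void doWork()
    {
        // placeholder workload
        java.util.stream.LongStream.range(0, 1_000).sum();
    }
}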
io/trino/operator/AggregationOperator.java
@@ -81,6 +81,7 @@ private enum State
     private final OperatorContext operatorContext;
     private final LocalMemoryContext userMemoryContext;
     private final List<Aggregator> aggregates;
+    private final AggregationMetrics aggregationMetrics = new AggregationMetrics();

     private State state = State.NEEDS_INPUT;

@@ -90,7 +91,7 @@ public AggregationOperator(OperatorContext operatorContext, List<AggregatorFacto
         this.userMemoryContext = operatorContext.localUserMemoryContext();

         aggregates = aggregatorFactories.stream()
-                .map(AggregatorFactory::createAggregator)
+                .map(factory -> factory.createAggregator(aggregationMetrics))
                 .collect(toImmutableList());
     }

@@ -111,6 +112,7 @@ public void finish()
     @Override
     public void close()
     {
+        updateOperatorMetrics();
         userMemoryContext.setBytes(0);
     }

@@ -144,6 +146,7 @@ public void addInput(Page page)
     public Page getOutput()
     {
         if (state != State.HAS_OUTPUT) {
+            updateOperatorMetrics();
             return null;
         }

@@ -162,6 +165,12 @@ public Page getOutput()
         }

         state = State.FINISHED;
+        updateOperatorMetrics();
         return pageBuilder.build();
     }
+
+    private void updateOperatorMetrics()
+    {
+        operatorContext.setLatestMetrics(aggregationMetrics.getMetrics());
+    }
 }
io/trino/operator/HashAggregationOperator.java
@@ -15,7 +15,6 @@

 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.ListenableFuture;
 import io.airlift.units.DataSize;
 import io.trino.memory.context.LocalMemoryContext;
@@ -26,10 +25,8 @@
 import io.trino.operator.aggregation.partial.PartialAggregationController;
 import io.trino.operator.aggregation.partial.SkipAggregationBuilder;
 import io.trino.operator.scalar.CombineHashFunction;
-import io.trino.plugin.base.metrics.LongCount;
 import io.trino.spi.Page;
 import io.trino.spi.PageBuilder;
-import io.trino.spi.metrics.Metrics;
 import io.trino.spi.type.BigintType;
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeOperators;
@@ -53,7 +50,6 @@
 public class HashAggregationOperator
         implements Operator
 {
-    static final String INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME = "Input rows processed without partial aggregation enabled";
     private static final double MERGE_WITH_MEMORY_RATIO = 0.9;

     public static class HashAggregationOperatorFactory
@@ -284,14 +280,14 @@ public OperatorFactory duplicate()
     private final SpillerFactory spillerFactory;
     private final FlatHashStrategyCompiler flatHashStrategyCompiler;
     private final TypeOperators typeOperators;
+    private final AggregationMetrics aggregationMetrics = new AggregationMetrics();

     private final List<Type> types;

     private HashAggregationBuilder aggregationBuilder;
     private final LocalMemoryContext memoryContext;
     private WorkProcessor<Page> outputPages;
     private long totalInputRowsProcessed;
-    private long inputRowsProcessedWithPartialAggregationDisabled;
     private boolean finishing;
     private boolean finished;

@@ -392,7 +388,7 @@ public void addInput(Page page)
                     .map(PartialAggregationController::isPartialAggregationDisabled)
                     .orElse(false);
             if (step.isOutputPartial() && partialAggregationDisabled) {
-                aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext);
+                aggregationBuilder = new SkipAggregationBuilder(groupByChannels, hashChannel, aggregatorFactories, memoryContext, aggregationMetrics);
             }
             else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
                 // TODO: We ignore spillEnabled here if any aggregate has ORDER BY clause or DISTINCT because they are not yet implemented for spilling.
@@ -413,7 +409,8 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
                             return true;
                         }
                         return operatorContext.isWaitingForMemory().isDone();
-                    });
+                    },
+                    aggregationMetrics);
             }
             else {
                 aggregationBuilder = new SpillableHashAggregationBuilder(
@@ -428,7 +425,8 @@ else if (step.isOutputPartial() || !spillEnabled || !isSpillable()) {
                         memoryLimitForMergeWithMemory,
                         spillerFactory,
                         flatHashStrategyCompiler,
-                        typeOperators);
+                        typeOperators,
+                        aggregationMetrics);
             }

             // assume initial aggregationBuilder is not full
@@ -537,9 +535,7 @@ public HashAggregationBuilder getAggregationBuilder()
     private void closeAggregationBuilder()
     {
         if (aggregationBuilder instanceof SkipAggregationBuilder) {
-            inputRowsProcessedWithPartialAggregationDisabled += aggregationInputRowsProcessed;
-            operatorContext.setLatestMetrics(new Metrics(ImmutableMap.of(
-                    INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME, new LongCount(inputRowsProcessedWithPartialAggregationDisabled))));
+            aggregationMetrics.recordInputRowsProcessedWithPartialAggregationDisabled(aggregationInputRowsProcessed);
             partialAggregationController.ifPresent(controller -> controller.onFlush(aggregationInputBytesProcessed, aggregationInputRowsProcessed, OptionalLong.empty()));
         }
         else {
@@ -549,6 +545,8 @@ private void closeAggregationBuilder()
         aggregationInputRowsProcessed = 0;
         aggregationUniqueRowsProduced = 0;

+        operatorContext.setLatestMetrics(aggregationMetrics.getMetrics());
+
         outputPages = null;
         if (aggregationBuilder != null) {
             aggregationBuilder.close();
@@ -586,7 +584,7 @@ private Page getGlobalAggregationOutput()
         }

         for (AggregatorFactory aggregatorFactory : aggregatorFactories) {
-            aggregatorFactory.createAggregator().evaluate(output.getBlockBuilder(channel));
+            aggregatorFactory.createAggregator(aggregationMetrics).evaluate(output.getBlockBuilder(channel));
             channel++;
         }
     }
io/trino/operator/MeasuredGroupByHashWork.java (new file)
@@ -0,0 +1,44 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator;

import static java.util.Objects.requireNonNull;

public class MeasuredGroupByHashWork<T>
        implements Work<T>
{
    private final Work<T> delegate;
    private final AggregationMetrics metrics;

    public MeasuredGroupByHashWork(Work<T> delegate, AggregationMetrics metrics)
    {
        this.delegate = requireNonNull(delegate, "delegate is null");
        this.metrics = requireNonNull(metrics, "metrics is null");
    }

    @Override
    public boolean process()
    {
        long start = System.nanoTime();
        boolean result = delegate.process();
        metrics.recordGroupByHashUpdateTimeSince(start);
[Inline review comment (Member)]
nit: alternatively we could use OperationTiming as in
    private final OperationTiming addInputTiming = new OperationTiming();
which would also track number of calls and wall time.
But I'm fine keeping it as is for now

        return result;
    }

    @Override
    public T getResult()
    {
        return delegate.getResult();
    }
}
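The call sites that apply this wrapper are not among the hunks shown here; presumably the hash-aggregation builders wrap the incremental Work returned by GroupByHash.addPage along these lines (a hypothetical sketch; the GroupByHash, Page, and metrics instances are assumed to come from the surrounding builder):

    // Hypothetical call-site sketch, not from this PR.
    private void addPageMeasured(GroupByHash groupByHash, Page page, AggregationMetrics aggregationMetrics)
    {
        // Every process() slice feeds "Group by hash update CPU time".
        Work<?> work = new MeasuredGroupByHashWork<>(groupByHash.addPage(page), aggregationMetrics);
        while (!work.process()) {
            // a real builder yields to the driver here and resumes later
        }
    }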
io/trino/operator/StreamingAggregationOperator.java
@@ -23,6 +23,7 @@
 import io.trino.spi.Page;
 import io.trino.spi.PageBuilder;
 import io.trino.spi.block.Block;
+import io.trino.spi.metrics.Metrics;
 import io.trino.spi.type.Type;
 import io.trino.sql.gen.JoinCompiler;
 import io.trino.sql.planner.plan.PlanNodeId;
@@ -134,6 +135,7 @@ public Factory duplicate()
     }

     private final WorkProcessor<Page> pages;
+    private final AggregationMetrics aggregationMetrics = new AggregationMetrics();

     private StreamingAggregationOperator(
             ProcessorContext processorContext,
@@ -151,7 +153,8 @@ private StreamingAggregationOperator(
                 groupByTypes,
                 groupByChannels,
                 aggregatorFactories,
-                joinCompiler));
+                joinCompiler,
+                aggregationMetrics));
     }

     @Override
@@ -160,6 +163,12 @@ public WorkProcessor<Page> getOutputPages()
         return pages;
     }

+    @Override
+    public Metrics getMetrics()
+    {
+        return aggregationMetrics.getMetrics();
+    }
+
     private static class StreamingAggregation
             implements Transformation<Page, Page>
     {
@@ -168,6 +177,7 @@ private static class StreamingAggregation
         private final int[] groupByChannels;
         private final List<AggregatorFactory> aggregatorFactories;
         private final PagesHashStrategy pagesHashStrategy;
+        private final AggregationMetrics aggregationMetrics;

         private List<Aggregator> aggregates;
         private final PageBuilder pageBuilder;
@@ -180,7 +190,8 @@ private StreamingAggregation(
                 List<Type> groupByTypes,
                 List<Integer> groupByChannels,
                 List<AggregatorFactory> aggregatorFactories,
-                JoinCompiler joinCompiler)
+                JoinCompiler joinCompiler,
+                AggregationMetrics aggregationMetrics)
         {
             requireNonNull(processorContext, "processorContext is null");
             this.userMemoryContext = processorContext.getMemoryTrackingContext().localUserMemoryContext();
@@ -189,7 +200,7 @@ private StreamingAggregation(
             this.aggregatorFactories = requireNonNull(aggregatorFactories, "aggregatorFactories is null");

             this.aggregates = aggregatorFactories.stream()
-                    .map(AggregatorFactory::createAggregator)
+                    .map(factory -> factory.createAggregator(aggregationMetrics))
                     .collect(toImmutableList());
             this.pageBuilder = new PageBuilder(toTypes(groupByTypes, aggregates));
             requireNonNull(joinCompiler, "joinCompiler is null");
@@ -200,6 +211,7 @@ private StreamingAggregation(
                     sourceTypes.stream()
                             .map(type -> new ObjectArrayList<Block>())
                             .collect(toImmutableList()), OptionalInt.empty());
+            this.aggregationMetrics = requireNonNull(aggregationMetrics, "aggregationMetrics is null");
         }

         @Override
@@ -317,7 +329,7 @@ private void evaluateAndFlushGroup(Page page, int position)
             }

             aggregates = aggregatorFactories.stream()
-                    .map(AggregatorFactory::createAggregator)
+                    .map(factory -> factory.createAggregator(aggregationMetrics))
                     .collect(toImmutableList());
         }
io/trino/operator/aggregation/Aggregator.java
@@ -14,6 +14,7 @@
 package io.trino.operator.aggregation;

 import com.google.common.primitives.Ints;
+import io.trino.operator.AggregationMetrics;
 import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
@@ -36,6 +37,7 @@ public class Aggregator
     private final int[] inputChannels;
     private final OptionalInt maskChannel;
     private final AggregationMaskBuilder maskBuilder;
+    private final AggregationMetrics metrics;

     public Aggregator(
             Accumulator accumulator,
@@ -44,7 +46,8 @@ public Aggregator(
             Type finalType,
             List<Integer> inputChannels,
             OptionalInt maskChannel,
-            AggregationMaskBuilder maskBuilder)
+            AggregationMaskBuilder maskBuilder,
+            AggregationMetrics metrics)
     {
         this.accumulator = requireNonNull(accumulator, "accumulator is null");
         this.step = requireNonNull(step, "step is null");
@@ -53,6 +56,7 @@ public Aggregator(
         this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null"));
         this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
         this.maskBuilder = requireNonNull(maskBuilder, "maskBuilder is null");
+        this.metrics = requireNonNull(metrics, "metrics is null");
         checkArgument(step.isInputRaw() || inputChannels.size() == 1, "expected 1 input channel for intermediate aggregation");
     }

@@ -77,10 +81,14 @@ public void processPage(Page page)
             if (mask.isSelectNone()) {
                 return;
             }
+            long start = System.nanoTime();
             accumulator.addInput(arguments, mask);
+            metrics.recordAccumulatorUpdateTimeSince(start);
         }
         else {
+            long start = System.nanoTime();
             accumulator.addIntermediate(page.getBlock(inputChannels[0]));
+            metrics.recordAccumulatorUpdateTimeSince(start);
         }
     }

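Since the metric names live in AggregationMetrics, and the partial-aggregation counter name is exposed @VisibleForTesting, a same-package test can assert on the accumulated values without hard-coding display names. A hypothetical sketch in the JUnit/AssertJ style of Trino's tests (it assumes LongCount exposes getTotal() via io.trino.spi.metrics.Count):

    @Test
    public void testPartialAggregationDisabledRowsAccumulate()
    {
        AggregationMetrics metrics = new AggregationMetrics();
        metrics.recordInputRowsProcessedWithPartialAggregationDisabled(100);
        metrics.recordInputRowsProcessedWithPartialAggregationDisabled(25);

        LongCount rows = (LongCount) metrics.getMetrics().getMetrics()
                .get(AggregationMetrics.INPUT_ROWS_WITH_PARTIAL_AGGREGATION_DISABLED_METRIC_NAME);
        assertThat(rows.getTotal()).isEqualTo(125);
    }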