[opt](coordinator) optimize parallel degree of shuffle when use nereids #44754

Merged · 4 commits · Jan 10, 2025
8 changes: 4 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_collect.h
@@ -263,10 +263,10 @@ struct AggregateFunctionCollectListData<StringRef, HasLimit> {
}
max_size = rhs.max_size;

data->insert_range_from(
*rhs.data, 0,
std::min(assert_cast<size_t, TypeCheckOnRelease::DISABLE>(max_size - size()),
rhs.size()));
data->insert_range_from(*rhs.data, 0,
std::min(assert_cast<size_t, TypeCheckOnRelease::DISABLE>(
static_cast<size_t>(max_size - size())),
rhs.size()));
} else {
data->insert_range_from(*rhs.data, 0, rhs.size());
}
@@ -46,7 +46,6 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.SetMultimap;
import org.apache.logging.log4j.LogManager;
@@ -136,6 +135,16 @@ private FragmentIdMapping<DistributedPlan> linkPlans(FragmentIdMapping<Distribut
link.getKey(),
enableShareHashTableForBroadcastJoin
);
for (Entry<DataSink, List<AssignedJob>> kv :
((PipelineDistributedPlan) link.getValue()).getDestinations().entrySet()) {
if (kv.getValue().isEmpty()) {
int sourceFragmentId = link.getValue().getFragmentJob().getFragment().getFragmentId().asInt();
String msg = "Invalid plan which exchange not contains receiver, "
+ "exchange id: " + kv.getKey().getExchNodeId().asInt()
+ ", source fragmentId: " + sourceFragmentId;
throw new IllegalStateException(msg);
}
}
}
}
return plans;
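The loop added above is a fail-fast sanity check: every DataSink registered in a PipelineDistributedPlan's destinations must have at least one receiver instance, otherwise the sender has nowhere to deliver rows and the bug would only surface at runtime. When it triggers, the exception carries the exchange id and the source fragment id, e.g. (values illustrative) "Invalid plan which exchange not contains receiver, exchange id: 5, source fragmentId: 2".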
@@ -184,7 +193,7 @@ private List<AssignedJob> filterInstancesWhichCanReceiveDataFromRemote(
boolean useLocalShuffle = receiverPlan.getInstanceJobs().stream()
.anyMatch(LocalShuffleAssignedJob.class::isInstance);
if (useLocalShuffle) {
return getFirstInstancePerShareScan(receiverPlan);
return getFirstInstancePerWorker(receiverPlan.getInstanceJobs());
} else if (enableShareHashTableForBroadcastJoin && linkNode.isRightChildOfBroadcastHashJoin()) {
return getFirstInstancePerWorker(receiverPlan.getInstanceJobs());
} else {
@@ -221,17 +230,6 @@ private List<AssignedJob> sortDestinationInstancesByBuckets(
return Arrays.asList(instances);
}

private List<AssignedJob> getFirstInstancePerShareScan(PipelineDistributedPlan plan) {
List<AssignedJob> canReceiveDataFromRemote = Lists.newArrayListWithCapacity(plan.getInstanceJobs().size());
for (AssignedJob instanceJob : plan.getInstanceJobs()) {
LocalShuffleAssignedJob localShuffleJob = (LocalShuffleAssignedJob) instanceJob;
if (!localShuffleJob.receiveDataFromLocal) {
canReceiveDataFromRemote.add(localShuffleJob);
}
}
return canReceiveDataFromRemote;
}

private List<AssignedJob> getFirstInstancePerWorker(List<AssignedJob> instances) {
Map<DistributedPlanWorker, AssignedJob> firstInstancePerWorker = Maps.newLinkedHashMap();
for (AssignedJob instance : instances) {
@@ -91,7 +91,7 @@ protected List<AssignedJob> insideMachineParallelization(

// now we should compute how many instances to process the data,
// for example: two instances
int instanceNum = degreeOfParallelism(scanSourceMaxParallel);
int instanceNum = degreeOfParallelism(scanSourceMaxParallel, useLocalShuffleToAddParallel);

if (useLocalShuffleToAddParallel) {
assignLocalShuffleJobs(scanSource, instanceNum, instances, context, worker);
@@ -129,7 +129,7 @@ protected void assignedDefaultJobs(ScanSource scanSource, int instanceNum, List<
protected void assignLocalShuffleJobs(ScanSource scanSource, int instanceNum, List<AssignedJob> instances,
ConnectContext context, DistributedPlanWorker worker) {
// only generate one instance to scan all data, in this step
List<ScanSource> instanceToScanRanges = scanSource.parallelize(scanNodes, 1);
List<ScanSource> assignedJoinBuckets = scanSource.parallelize(scanNodes, instanceNum);

// when the data is not big but the aggregation is too slow, we will use 1 instance to scan the data,
// and use more instances (to ***add parallelism***) to process the aggregation.
@@ -144,23 +144,23 @@ protected void assignLocalShuffleJobs(ScanSource scanSource, int instanceNum, Li
// |(share scan node, instance1 will scan all data and local shuffle to other local instances |
// | to parallel compute this data) |
// +------------------------------------------------------------------------------------------------+
ScanSource shareScanSource = instanceToScanRanges.get(0);
ScanSource shareScanSource = assignedJoinBuckets.get(0);

// one scan range generates multiple instances,
// different instances reference the same scan source
int shareScanId = shareScanIdGenerator.getAndIncrement();
ScanSource emptyShareScanSource = shareScanSource.newEmpty();
for (int i = 0; i < instanceNum; i++) {
LocalShuffleAssignedJob instance = new LocalShuffleAssignedJob(
instances.size(), shareScanId, i > 0,
context.nextInstanceId(), this, worker,
i == 0 ? shareScanSource : emptyShareScanSource
instances.size(), shareScanId, context.nextInstanceId(), this, worker,
// only the first instance needs to scan data
i == 0 ? scanSource : emptyShareScanSource
);
instances.add(instance);
}
}
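
To make the share-scan diagram above concrete: under local shuffle every instance on the worker carries the same shareScanId, but only the first one receives the real scan source; the others get an empty copy and are fed rows by local shuffle. A minimal self-contained sketch of that layout (assumed values; simplified stand-ins for the real job types):

public class LocalShuffleLayoutSketch {
    public static void main(String[] args) {
        int instanceNum = 3;   // e.g. degreeOfParallelism(...) returned 3
        int shareScanId = 0;   // shared by all instances on this worker

        for (int i = 0; i < instanceNum; i++) {
            // mirrors `i == 0 ? scanSource : emptyShareScanSource` above:
            // only the first instance reads the tablets, the rest only compute
            String role = (i == 0) ? "scans all data" : "receives rows via local shuffle";
            System.out.println("instance " + i + " (shareScanId=" + shareScanId + "): " + role);
        }
    }
}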

protected int degreeOfParallelism(int maxParallel) {
protected int degreeOfParallelism(int maxParallel, boolean useLocalShuffleToAddParallel) {
Preconditions.checkArgument(maxParallel > 0, "maxParallel must be positive");
if (!fragment.getDataPartition().isPartitioned()) {
return 1;
@@ -179,6 +179,10 @@ protected int degreeOfParallelism(int maxParallel) {
}
}

if (useLocalShuffleToAddParallel) {
return Math.max(fragment.getParallelExecNum(), 1);
}

// the scan instance num should not larger than the tablets num
return Math.min(maxParallel, Math.max(fragment.getParallelExecNum(), 1));
}
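
A worked before/after example of this method (assumed numbers: a worker with only 2 tablets to scan and parallel_pipeline_task_num = 8); this is a simplified restatement of the change, not the literal coordinator code:

public class DegreeOfParallelismSketch {
    static final int PARALLEL_EXEC_NUM = 8;  // fragment.getParallelExecNum()

    static int oldDop(int maxParallel) {
        // before this PR: always capped by the scan parallelism (tablet count)
        return Math.min(maxParallel, Math.max(PARALLEL_EXEC_NUM, 1));
    }

    static int newDop(int maxParallel, boolean useLocalShuffleToAddParallel) {
        if (useLocalShuffleToAddParallel) {
            // after: one instance scans and local shuffle fans rows out, so the
            // compute parallelism no longer depends on how many tablets exist
            return Math.max(PARALLEL_EXEC_NUM, 1);
        }
        return Math.min(maxParallel, Math.max(PARALLEL_EXEC_NUM, 1));
    }

    public static void main(String[] args) {
        int scanSourceMaxParallel = 2;  // only 2 tablets on this worker
        System.out.println(oldDop(scanSourceMaxParallel));        // 2 instances
        System.out.println(newDop(scanSourceMaxParallel, true));  // 8 instances
    }
}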
@@ -31,28 +31,17 @@
*/
public class LocalShuffleAssignedJob extends StaticAssignedJob {
public final int shareScanId;
public final boolean receiveDataFromLocal;

public LocalShuffleAssignedJob(
int indexInUnassignedJob, int shareScanId, boolean receiveDataFromLocal, TUniqueId instanceId,
int indexInUnassignedJob, int shareScanId, TUniqueId instanceId,
UnassignedJob unassignedJob,
DistributedPlanWorker worker, ScanSource scanSource) {
super(indexInUnassignedJob, instanceId, unassignedJob, worker, scanSource);
this.shareScanId = shareScanId;
this.receiveDataFromLocal = receiveDataFromLocal;
}

@Override
protected Map<String, String> extraInfo() {
return ImmutableMap.of("shareScanIndex", String.valueOf(shareScanId));
}

@Override
protected String formatScanSourceString() {
if (receiveDataFromLocal) {
return "read data from first instance of " + getAssignedWorker();
} else {
return super.formatScanSourceString();
}
}
}
@@ -30,19 +30,19 @@ public class LocalShuffleBucketJoinAssignedJob extends LocalShuffleAssignedJob {
private volatile Set<Integer> assignedJoinBucketIndexes;

public LocalShuffleBucketJoinAssignedJob(
int indexInUnassignedJob, int shareScanId, boolean receiveDataFromLocal,
int indexInUnassignedJob, int shareScanId,
TUniqueId instanceId, UnassignedJob unassignedJob,
DistributedPlanWorker worker, ScanSource scanSource,
Set<Integer> assignedJoinBucketIndexes) {
super(indexInUnassignedJob, shareScanId, receiveDataFromLocal, instanceId, unassignedJob, worker, scanSource);
super(indexInUnassignedJob, shareScanId, instanceId, unassignedJob, worker, scanSource);
this.assignedJoinBucketIndexes = Utils.fastToImmutableSet(assignedJoinBucketIndexes);
}

public Set<Integer> getAssignedJoinBucketIndexes() {
return assignedJoinBucketIndexes;
}

public void addAssignedJoinBucketIndexes(Set<Integer> joinBucketIndexes) {
public synchronized void addAssignedJoinBucketIndexes(Set<Integer> joinBucketIndexes) {
this.assignedJoinBucketIndexes = ImmutableSet.<Integer>builder()
.addAll(assignedJoinBucketIndexes)
.addAll(joinBucketIndexes)
@@ -32,7 +32,7 @@

/** UnassignedGatherJob */
public class UnassignedGatherJob extends AbstractUnassignedJob {
private boolean useLocalShuffleToAddParallel;
private boolean useSerialSource;

public UnassignedGatherJob(
StatementContext statementContext, PlanFragment fragment,
@@ -44,24 +44,24 @@ public UnassignedGatherJob(
public List<AssignedJob> computeAssignedJobs(
DistributeContext distributeContext, ListMultimap<ExchangeNode, AssignedJob> inputJobs) {
ConnectContext connectContext = statementContext.getConnectContext();
useLocalShuffleToAddParallel = fragment.useSerialSource(connectContext);
useSerialSource = fragment.useSerialSource(connectContext);

int expectInstanceNum = degreeOfParallelism();

DistributedPlanWorker selectedWorker = distributeContext.selectedWorkers.tryToSelectRandomUsedWorker();
if (useLocalShuffleToAddParallel) {
if (useSerialSource) {
// Using serial source means a serial source operator will be used in this fragment (e.g. data will be
// shuffled to only 1 exchange operator) and then split by the following local exchanger
ImmutableList.Builder<AssignedJob> instances = ImmutableList.builder();

DefaultScanSource shareScan = new DefaultScanSource(ImmutableMap.of());
LocalShuffleAssignedJob receiveDataFromRemote = new LocalShuffleAssignedJob(
0, 0, false,
0, 0,
connectContext.nextInstanceId(), this, selectedWorker, shareScan);

instances.add(receiveDataFromRemote);
for (int i = 1; i < expectInstanceNum; ++i) {
LocalShuffleAssignedJob receiveDataFromLocal = new LocalShuffleAssignedJob(
i, 0, true,
connectContext.nextInstanceId(), this, selectedWorker, shareScan);
i, 0, connectContext.nextInstanceId(), this, selectedWorker, shareScan);
instances.add(receiveDataFromLocal);
}
return instances.build();
@@ -76,6 +76,6 @@ selectedWorker, new DefaultScanSource(ImmutableMap.of())
}

protected int degreeOfParallelism() {
return useLocalShuffleToAddParallel ? fragment.getParallelExecNum() : 1;
return useSerialSource ? Math.max(1, fragment.getParallelExecNum()) : 1;
}
}
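
With useSerialSource, the gather fragment is now given max(1, parallelExecNum) instances on a single worker: instance 0 is the only one wired to the remote exchange, and the remaining instances consume its output through the local exchanger. A rough sketch of the resulting layout (assumed parallel degree of 4):

public class GatherLayoutSketch {
    public static void main(String[] args) {
        int parallelExecNum = 4;                              // fragment.getParallelExecNum()
        int expectInstanceNum = Math.max(1, parallelExecNum); // degreeOfParallelism()

        for (int i = 0; i < expectInstanceNum; i++) {
            // mirrors the instance-building loop above: index 0 receives from
            // the remote exchange, indexes 1..N-1 are fed by the local exchanger
            String role = (i == 0) ? "receives from remote exchange" : "fed by local exchanger";
            System.out.println("gather instance " + i + ": " + role);
        }
    }
}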
@@ -38,12 +38,16 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
@@ -184,13 +188,23 @@ protected void assignLocalShuffleJobs(ScanSource scanSource, int instanceNum, Li
Set<Integer> assignedJoinBuckets
= ((BucketScanSource) assignJoinBuckets.get(i)).bucketIndexToScanNodeToTablets.keySet();
LocalShuffleBucketJoinAssignedJob instance = new LocalShuffleBucketJoinAssignedJob(
instances.size(), shareScanId, i > 0,
context.nextInstanceId(), this, worker,
instances.size(), shareScanId, context.nextInstanceId(),
this, worker,
i == 0 ? shareScanSource : emptyShareScanSource,
Utils.fastToImmutableSet(assignedJoinBuckets)
);
instances.add(instance);
}

for (int i = assignJoinBuckets.size(); i < instanceNum; ++i) {
LocalShuffleBucketJoinAssignedJob instance = new LocalShuffleBucketJoinAssignedJob(
instances.size(), shareScanId, context.nextInstanceId(),
this, worker, emptyShareScanSource,
// these instances do not need to join, because no bucket is assigned to them
ImmutableSet.of()
);
instances.add(instance);
}
}
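
The second loop pads the worker up to instanceNum instances when there are fewer bucket groups than instances: for example (numbers assumed), with instanceNum = 8 but only 3 bucket groups, 5 extra instances are created with an empty share-scan source and an empty bucket set, existing purely to add compute parallelism behind the local shuffle.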

private boolean shouldFillUpInstances(List<HashJoinNode> hashJoinNodes) {
@@ -224,10 +238,21 @@ private List<AssignedJob> fillUpInstances(List<AssignedJob> instances) {
olapScanNode, randomPartition, missingBucketIndexes);

boolean useLocalShuffle = instances.stream().anyMatch(LocalShuffleAssignedJob.class::isInstance);
Multimap<DistributedPlanWorker, AssignedJob> workerToAssignedJobs = ArrayListMultimap.create();
int maxNumInstancePerWorker = 1;
if (useLocalShuffle) {
for (AssignedJob instance : instances) {
workerToAssignedJobs.put(instance.getAssignedWorker(), instance);
}
for (Collection<AssignedJob> instanceList : workerToAssignedJobs.asMap().values()) {
maxNumInstancePerWorker = Math.max(maxNumInstancePerWorker, instanceList.size());
}
}

List<AssignedJob> newInstances = new ArrayList<>(instances);

for (Entry<DistributedPlanWorker, Collection<Integer>> workerToBuckets : missingBuckets.asMap().entrySet()) {
Map<Integer, Map<ScanNode, ScanRanges>> scanEmptyBuckets = Maps.newLinkedHashMap();
Set<Integer> assignedJoinBuckets = Utils.fastToImmutableSet(workerToBuckets.getValue());
for (Integer bucketIndex : workerToBuckets.getValue()) {
Map<ScanNode, ScanRanges> scanTableWithEmptyData = Maps.newLinkedHashMap();
for (ScanNode scanNode : scanNodes) {
@@ -236,42 +261,62 @@ private List<AssignedJob> fillUpInstances(List<AssignedJob> instances) {
scanEmptyBuckets.put(bucketIndex, scanTableWithEmptyData);
}

AssignedJob fillUpInstance = null;
DistributedPlanWorker worker = workerToBuckets.getKey();
BucketScanSource scanSource = new BucketScanSource(scanEmptyBuckets);
if (useLocalShuffle) {
// when using local shuffle, we should ensure every backend only processes one instance!
// so here we should try to merge the missing buckets into existing instances
boolean mergedBucketsInSameWorkerInstance = false;
for (AssignedJob newInstance : newInstances) {
if (newInstance.getAssignedWorker().equals(worker)) {
BucketScanSource bucketScanSource = (BucketScanSource) newInstance.getScanSource();
bucketScanSource.bucketIndexToScanNodeToTablets.putAll(scanEmptyBuckets);
mergedBucketsInSameWorkerInstance = true;

LocalShuffleBucketJoinAssignedJob instance = (LocalShuffleBucketJoinAssignedJob) newInstance;
instance.addAssignedJoinBucketIndexes(assignedJoinBuckets);
}
List<AssignedJob> sameWorkerInstances = (List) workerToAssignedJobs.get(worker);
if (sameWorkerInstances.isEmpty()) {
sameWorkerInstances = fillUpEmptyInstances(
maxNumInstancePerWorker, scanSource, worker, newInstances, context);
}
if (!mergedBucketsInSameWorkerInstance) {
fillUpInstance = new LocalShuffleBucketJoinAssignedJob(
newInstances.size(), shareScanIdGenerator.getAndIncrement(),
false, context.nextInstanceId(), this, worker, scanSource,
assignedJoinBuckets
);

LocalShuffleBucketJoinAssignedJob firstInstance
= (LocalShuffleBucketJoinAssignedJob) sameWorkerInstances.get(0);
BucketScanSource firstInstanceScanSource
= (BucketScanSource) firstInstance.getScanSource();
firstInstanceScanSource.bucketIndexToScanNodeToTablets.putAll(scanEmptyBuckets);

Iterator<Integer> assignedJoinBuckets = new LinkedHashSet<>(workerToBuckets.getValue()).iterator();
// make sure the first instance is assigned some buckets:
// if the first instance already has some buckets, start assigning the empty
// buckets from the second instance for balance; otherwise start with the first instance
int index = firstInstance.getAssignedJoinBucketIndexes().isEmpty() ? -1 : 0;
while (assignedJoinBuckets.hasNext()) {
Integer bucketIndex = assignedJoinBuckets.next();
assignedJoinBuckets.remove();

index = (index + 1) % sameWorkerInstances.size();
LocalShuffleBucketJoinAssignedJob instance
= (LocalShuffleBucketJoinAssignedJob) sameWorkerInstances.get(index);
instance.addAssignedJoinBucketIndexes(ImmutableSet.of(bucketIndex));
}
} else {
fillUpInstance = assignWorkerAndDataSources(
newInstances.add(assignWorkerAndDataSources(
newInstances.size(), context.nextInstanceId(), worker, scanSource
);
}
if (fillUpInstance != null) {
newInstances.add(fillUpInstance);
));
}
}
return newInstances;
}
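
The balancing loop above hands each missing bucket to the worker's instances in round-robin order, starting with the first instance only when it has no buckets yet. A self-contained sketch of just that loop (collections simplified to plain sets of bucket indexes):

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

public class BucketBalanceSketch {
    public static void main(String[] args) {
        // three instances on one worker; the first already owns bucket 0
        List<Set<Integer>> instanceBuckets = new ArrayList<>();
        instanceBuckets.add(new LinkedHashSet<>(List.of(0)));
        instanceBuckets.add(new LinkedHashSet<>());
        instanceBuckets.add(new LinkedHashSet<>());

        Set<Integer> missingBuckets = new LinkedHashSet<>(List.of(5, 6, 7));

        // mirrors: int index = firstInstance.getAssignedJoinBucketIndexes().isEmpty() ? -1 : 0;
        int index = instanceBuckets.get(0).isEmpty() ? -1 : 0;
        for (Integer bucket : missingBuckets) {
            index = (index + 1) % instanceBuckets.size();
            instanceBuckets.get(index).add(bucket);
        }

        // prints [[0, 7], [5], [6]]: buckets 5 and 6 go to the idle instances
        // first, and 7 wraps around to the first instance
        System.out.println(instanceBuckets);
    }
}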

private List<AssignedJob> fillUpEmptyInstances(
int maxNumInstancePerWorker, BucketScanSource scanSource, DistributedPlanWorker worker,
List<AssignedJob> existsInstances, ConnectContext context) {
int shareScanId = shareScanIdGenerator.getAndIncrement();
List<AssignedJob> newInstances = new ArrayList<>(maxNumInstancePerWorker);
for (int i = 0; i < maxNumInstancePerWorker; i++) {
LocalShuffleBucketJoinAssignedJob newInstance = new LocalShuffleBucketJoinAssignedJob(
existsInstances.size(), shareScanId,
context.nextInstanceId(), this, worker,
scanSource.newEmpty(),
ImmutableSet.of()
);
existsInstances.add(newInstance);
newInstances.add(newInstance);
}
return newInstances;
}
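
fillUpEmptyInstances keeps the per-worker instance count uniform: when a worker that needs filler buckets has no instances for this fragment yet, it is padded with maxNumInstancePerWorker empty local-shuffle instances sharing a fresh shareScanId, so the round-robin balancing above sees the same fan-out on every worker. This appears to be why maxNumInstancePerWorker is computed up front in fillUpInstances; the diff itself does not spell the intent out.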

private int fullBucketNum() {
for (ScanNode scanNode : scanNodes) {
if (scanNode instanceof OlapScanNode) {