-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Spark Shuffle manager implementation.
- Loading branch information
1 parent
742f0d7
commit ecfedfe
Showing
15 changed files
with
865 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
101 changes: 101 additions & 0 deletions
101
src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
/* | ||
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. | ||
* See file LICENSE for terms. | ||
*/ | ||
package org.apache.spark.shuffle.ucx | ||
|
||
import java.io.File | ||
import java.nio.{ByteBuffer, MappedByteBuffer} | ||
import java.nio.channels.FileChannel | ||
import java.nio.file.StandardOpenOption | ||
import java.util.concurrent.ConcurrentHashMap | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
import org.openucx.jucx.UcxUtils | ||
import org.apache.spark.shuffle.IndexShuffleBlockResolver | ||
import org.apache.spark.storage.ShuffleBlockId | ||
import org.apache.spark.unsafe.Platform | ||
|
||
|
||
case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { | ||
|
||
def this(shuffleBlockId: ShuffleBlockId) = { | ||
this(shuffleBlockId.shuffleId, shuffleBlockId.mapId, shuffleBlockId.reduceId) | ||
} | ||
|
||
def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId | ||
} | ||
|
||
case class BufferBackedBlock(buffer: ByteBuffer) extends Block { | ||
override def getMemoryBlock: MemoryBlock = MemoryBlock(UcxUtils.getAddress(buffer), buffer.capacity()) | ||
} | ||
|
||
class UcxShuffleBlockResolver(conf: UcxShuffleConf, transport: UcxShuffleTransport) | ||
extends IndexShuffleBlockResolver(conf) { | ||
|
||
type MapId = Long | ||
|
||
private val numPartitionsForMapId = new ConcurrentHashMap[MapId, Int] | ||
|
||
override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long, | ||
lengths: Array[Long], dataTmp: File): Unit = { | ||
super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) | ||
val dataFile = getDataFile(shuffleId, mapId) | ||
if (!dataFile.exists()) { | ||
return | ||
} | ||
numPartitionsForMapId.put(mapId, lengths.length) | ||
val fileChannel = FileChannel.open(dataFile.toPath, StandardOpenOption.READ, | ||
StandardOpenOption.WRITE) | ||
val mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0L, dataFile.length()) | ||
|
||
val baseAddress = UcxUtils.getAddress(mappedBuffer) | ||
fileChannel.close() | ||
|
||
// Register whole map output file as dummy block | ||
transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE), | ||
BufferBackedBlock(mappedBuffer)) | ||
|
||
val offsetSize = 8 * (lengths.length + 1) | ||
val indexBuf = Platform.allocateDirectBuffer(offsetSize) | ||
|
||
var offset = 0L | ||
indexBuf.putLong(offset) | ||
for (reduceId <- lengths.indices) { | ||
if (lengths(reduceId) > 0) { | ||
transport.register(UcxShuffleBlockId(shuffleId, mapId, reduceId), new Block { | ||
private val memoryBlock = MemoryBlock(baseAddress + offset, lengths(reduceId)) | ||
override def getMemoryBlock: MemoryBlock = memoryBlock | ||
}) | ||
offset += lengths(reduceId) | ||
indexBuf.putLong(offset) | ||
} | ||
} | ||
|
||
if (transport.ucxShuffleConf.protocol == transport.ucxShuffleConf.PROTOCOL.ONE_SIDED) { | ||
transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE), BufferBackedBlock(indexBuf)) | ||
} | ||
} | ||
|
||
override def removeDataByMap(shuffleId: ShuffleId, mapId: Long): Unit = { | ||
transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE)) | ||
transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE)) | ||
|
||
val numRegisteredBlocks = numPartitionsForMapId.get(mapId) | ||
(0 until numRegisteredBlocks) | ||
.foreach(reduceId => transport.unregister(UcxShuffleBlockId(shuffleId, mapId, reduceId))) | ||
super.removeDataByMap(shuffleId, mapId) | ||
} | ||
|
||
override def stop(): Unit = { | ||
numPartitionsForMapId.keys.asScala.foreach(mapId => removeDataByMap(0, mapId)) | ||
super.stop() | ||
} | ||
|
||
} | ||
|
||
object BlocksConstants { | ||
val MAP_FILE: Int = -1 | ||
val INDEX_FILE: Int = -2 | ||
} |
85 changes: 85 additions & 0 deletions
85
src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* | ||
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. | ||
* See file LICENSE for terms. | ||
*/ | ||
package org.apache.spark.shuffle.ucx | ||
|
||
import java.util.concurrent.TimeUnit | ||
|
||
import org.openucx.jucx.{UcxException, UcxUtils} | ||
import org.apache.spark.SparkEnv | ||
import org.apache.spark.internal.Logging | ||
import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD | ||
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} | ||
import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager} | ||
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, BlockId => SparkBlockId} | ||
|
||
class UcxShuffleClient(transport: UcxShuffleTransport, | ||
blocksByAddress: Iterator[(BlockManagerId, Seq[(SparkBlockId, Long, Int)])]) | ||
extends BlockStoreClient with Logging { | ||
|
||
private val accurateThreshold = transport.ucxShuffleConf.conf.getSizeAsBytes(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key) | ||
|
||
private val blockSizes: Map[SparkBlockId, Long] = blocksByAddress | ||
.withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId } | ||
.flatMap { | ||
case (blockManagerId, blocks) => | ||
val blockIds = blocks.map { | ||
case (blockId, _, _) => | ||
val sparkBlockId = blockId.asInstanceOf[ShuffleBlockId] | ||
UcxShuffleBlockId(sparkBlockId.shuffleId, sparkBlockId.mapId, sparkBlockId.reduceId) | ||
} | ||
if (!transport.ucxShuffleConf.pinMemory) { | ||
transport.prefetchBlocks(blockManagerId.executorId, blockIds) | ||
} | ||
blocks.map { | ||
case (blockId, length, _) => | ||
if (length > accurateThreshold) { | ||
(blockId, (length * 1.2).toLong) | ||
} else { | ||
(blockId, accurateThreshold) | ||
} | ||
} | ||
}.toMap | ||
|
||
override def fetchBlocks(host: String, port: Int, execId: String, | ||
blockIds: Array[String], listener: BlockFetchingListener, | ||
downloadFileManager: DownloadFileManager): Unit = { | ||
val ucxBlockIds = new Array[BlockId](blockIds.length) | ||
val memoryBlocks = new Array[MemoryBlock](blockIds.length) | ||
val callbacks = new Array[OperationCallback](blockIds.length) | ||
for (i <- blockIds.indices) { | ||
val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[ShuffleBlockId] | ||
if (!blockSizes.contains(blockId)) { | ||
throw new UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}") | ||
} | ||
val resultMemory = transport.memoryPool.get(blockSizes(blockId)) | ||
ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId) | ||
memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId)) | ||
callbacks(i) = (result: OperationResult) => { | ||
if (result.getStatus == OperationStatus.SUCCESS) { | ||
val stats = result.getStats.get | ||
logInfo(s" Received block ${ucxBlockIds(i)} " + | ||
s"of size: ${stats.recvSize} " + | ||
s"in ${TimeUnit.NANOSECONDS.toMillis(stats.getElapsedTimeNs)} ms") | ||
val buffer = UcxUtils.getByteBufferView(resultMemory.address, result.getStats.get.recvSize.toInt) | ||
listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { | ||
override def release: ManagedBuffer = { | ||
transport.memoryPool.put(resultMemory) | ||
this | ||
} | ||
}) | ||
} else { | ||
logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" + | ||
s" ${result.getError.getMessage}") | ||
throw new UcxException(result.getError.getMessage) | ||
} | ||
} | ||
} | ||
transport.fetchBlocksByBlockIds(execId, ucxBlockIds, memoryBlocks, callbacks) | ||
} | ||
|
||
override def close(): Unit = { | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
102 changes: 102 additions & 0 deletions
102
src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/* | ||
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. | ||
* See file LICENSE for terms. | ||
*/ | ||
package org.apache.spark.shuffle.ucx | ||
|
||
import scala.concurrent.ExecutionContext.Implicits.global | ||
import scala.util.Success | ||
|
||
import org.apache.spark.rpc.RpcEnv | ||
import org.apache.spark.shuffle._ | ||
import org.apache.spark.shuffle.sort.SortShuffleManager | ||
import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch | ||
import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} | ||
import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint} | ||
import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer | ||
import org.apache.spark.util.RpcUtils | ||
import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, TaskContext} | ||
|
||
|
||
class UcxShuffleManager(conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) { | ||
|
||
val ucxShuffleConf = new UcxShuffleConf(conf) | ||
|
||
lazy val ucxShuffleTransport: UcxShuffleTransport = if (!isDriver) { | ||
new UcxShuffleTransport(ucxShuffleConf, "init") | ||
} else { | ||
null | ||
} | ||
|
||
@volatile private var initialized: Boolean = false | ||
|
||
override val shuffleBlockResolver = | ||
new UcxShuffleBlockResolver(ucxShuffleConf, ucxShuffleTransport) | ||
|
||
logInfo("Starting UcxShuffleManager") | ||
|
||
def initTransport(): Unit = this.synchronized { | ||
if (!initialized) { | ||
val driverEndpointName = "ucx-shuffle-driver" | ||
if (isDriver) { | ||
val rpcEnv = SparkEnv.get.rpcEnv | ||
val driverEndpoint = new UcxDriverRpcEndpoint(rpcEnv) | ||
rpcEnv.setupEndpoint(driverEndpointName, driverEndpoint) | ||
} else { | ||
val blockManager = SparkEnv.get.blockManager.blockManagerId | ||
ucxShuffleTransport.executorId = blockManager.executorId | ||
val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.host, | ||
blockManager.port, conf, new SecurityManager(conf), 1, clientMode=false) | ||
logDebug("Initializing ucx transport") | ||
val address = ucxShuffleTransport.init() | ||
val executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxShuffleTransport) | ||
val endpoint = rpcEnv.setupEndpoint( | ||
s"ucx-shuffle-executor-${blockManager.executorId}", | ||
executorEndpoint) | ||
|
||
val driverEndpoint = RpcUtils.makeDriverRef(driverEndpointName, conf, rpcEnv) | ||
driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId, | ||
endpoint, new SerializableDirectBuffer(address))) | ||
.andThen{ | ||
case Success(msg) => | ||
logInfo(s"Receive reply $msg") | ||
executorEndpoint.receive(msg) | ||
} | ||
} | ||
initialized = true | ||
} | ||
} | ||
|
||
override def getReader[K, C](handle: ShuffleHandle, | ||
startPartition: Int, | ||
endPartition: Int, | ||
context: TaskContext, | ||
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { | ||
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( | ||
handle.shuffleId, startPartition, endPartition) | ||
new UcxShuffleReader(ucxShuffleTransport, | ||
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, | ||
shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) | ||
} | ||
|
||
override def getReaderForRange[K, C]( handle: ShuffleHandle, | ||
startMapIndex: Int, | ||
endMapIndex: Int, | ||
startPartition: Int, | ||
endPartition: Int, | ||
context: TaskContext, | ||
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { | ||
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange( | ||
handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) | ||
new UcxShuffleReader(ucxShuffleTransport, | ||
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, | ||
shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) | ||
} | ||
|
||
override def stop(): Unit = { | ||
if (ucxShuffleTransport != null) { | ||
ucxShuffleTransport.close() | ||
} | ||
super.stop() | ||
} | ||
} |
Oops, something went wrong.