Spark Shuffle manager implementation.
petro-rudenko committed Oct 30, 2020
1 parent 742f0d7 commit 71fc252
Showing 13 changed files with 714 additions and 68 deletions.
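
Note: to try the plugin introduced here, Spark has to be pointed at the new manager class. A minimal, illustrative sketch (spark.shuffle.manager is a standard Spark setting; the class name is taken from this commit and the spark.shuffle.ucx.* key from UcxShuffleConf below; this is not part of the patch):

import org.apache.spark.SparkConf

// Illustrative configuration only: swaps Spark's sort-based shuffle for the
// UCX-backed manager added in this commit.
val conf = new SparkConf()
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.ucx.UcxShuffleManager")
  .set("spark.shuffle.ucx.memoryPinning", "true") // key defined in UcxShuffleConf
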
2 changes: 1 addition & 1 deletion pom.xml
@@ -58,7 +58,7 @@ See file LICENSE for terms.
   </dependencies>
 
   <build>
-    <finalName>${project.artifactId}-${project.version}-for-spark-${spark.version}</finalName>
+    <finalName>${project.artifactId}-${project.version}-for-spark-3.0</finalName>
     <plugins>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
84 changes: 84 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala
@@ -0,0 +1,84 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx

import java.io.File
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import java.nio.file.StandardOpenOption
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

import org.openucx.jucx.UcxUtils
import org.apache.spark.shuffle.IndexShuffleBlockResolver
import org.apache.spark.storage.ShuffleBlockId


case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {

  def this(shuffleBlockId: ShuffleBlockId) = {
    this(shuffleBlockId.shuffleId, shuffleBlockId.mapId, shuffleBlockId.reduceId)
  }

  def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}

class UcxShuffleBlockResolver(conf: UcxShuffleConf, transport: UcxShuffleTransport)
  extends IndexShuffleBlockResolver(conf) {

  type MapId = Long

  private val numPartitionsForMapId = new ConcurrentHashMap[MapId, Int]
  private val mapIdToMappedBuffer = new ConcurrentHashMap[MapId, MappedByteBuffer]

  override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long,
                                       lengths: Array[Long], dataTmp: File): Unit = {
    super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
    val dataFile = getDataFile(shuffleId, mapId)
    if (!dataFile.exists()) {
      return
    }
    numPartitionsForMapId.put(mapId, lengths.length)
    // Memory-map the committed data file. The mapping stays valid after the
    // channel is closed and is kept alive until removeDataByMap.
    val fileChannel = FileChannel.open(dataFile.toPath, StandardOpenOption.READ,
      StandardOpenOption.WRITE)
    val mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0L, dataFile.length())
    mapIdToMappedBuffer.put(mapId, mappedBuffer)
    val baseAddress = UcxUtils.getAddress(mappedBuffer)
    fileChannel.close()

    // Register the whole map output file as a dummy block (reduceId = -1)
    transport.register(UcxShuffleBlockId(shuffleId, mapId, -1), new Block {
      override def getMemoryBlock: MemoryBlock = MemoryBlock(baseAddress, dataFile.length())
    })

    // Register each non-empty reduce partition at its offset within the file
    var offset = 0L
    for (reduceId <- lengths.indices) {
      if (lengths(reduceId) > 0) {
        transport.register(UcxShuffleBlockId(shuffleId, mapId, reduceId), new Block {
          private val memoryBlock = MemoryBlock(baseAddress + offset, lengths(reduceId))
          override def getMemoryBlock: MemoryBlock = memoryBlock
        })
        offset += lengths(reduceId)
      }
    }
  }

  override def removeDataByMap(shuffleId: ShuffleId, mapId: Long): Unit = {
    transport.unregister(UcxShuffleBlockId(shuffleId, mapId, -1))
    mapIdToMappedBuffer.remove(mapId)
    // Drop the bookkeeping entry as well, so stale map ids do not accumulate
    val numRegisteredBlocks = numPartitionsForMapId.remove(mapId)
    (0 until numRegisteredBlocks)
      .foreach(reduceId => transport.unregister(UcxShuffleBlockId(shuffleId, mapId, reduceId)))
    super.removeDataByMap(shuffleId, mapId)
  }

  override def stop(): Unit = {
    numPartitionsForMapId.keys.asScala.foreach(mapId => removeDataByMap(0, mapId))
    super.stop()
  }
}
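
The registration above relies on the map output file being a plain concatenation of reduce-partition segments: each block's address is the mapped file's base address plus the prefix sum of the preceding lengths. A self-contained sketch of that arithmetic (names are illustrative, not part of the patch):

// Illustrative: compute (reduceId, offset, length) for every non-empty
// segment, mirroring the registration loop in writeIndexFileAndCommit above.
def partitionSegments(lengths: Array[Long]): Seq[(Int, Long, Long)] = {
  val offsets = lengths.scanLeft(0L)(_ + _) // prefix sums = segment start offsets
  lengths.indices.collect {
    case reduceId if lengths(reduceId) > 0 =>
      (reduceId, offsets(reduceId), lengths(reduceId))
  }
}

// partitionSegments(Array(10L, 0L, 5L)) == Seq((0, 0, 10), (2, 10, 5))
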
85 changes: 85 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala
@@ -0,0 +1,85 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx

import java.util.concurrent.TimeUnit

import org.openucx.jucx.{UcxException, UcxUtils}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager}
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, BlockId => SparkBlockId}

class UcxShuffleClient(transport: UcxShuffleTransport,
                       blocksByAddress: Iterator[(BlockManagerId, Seq[(SparkBlockId, Long, Int)])])
  extends BlockStoreClient with Logging {

  private val accurateThreshold =
    transport.ucxShuffleConf.conf.getSizeAsBytes(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key)

  // Estimate a receive-buffer size for each remote block, skipping blocks that
  // live on this executor's own block manager.
  private val blockSizes: Map[SparkBlockId, Long] = blocksByAddress
    .withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId }
    .flatMap {
      case (blockManagerId, blocks) =>
        val blockIds = blocks.map {
          case (blockId, _, _) =>
            val sparkBlockId = blockId.asInstanceOf[ShuffleBlockId]
            UcxShuffleBlockId(sparkBlockId.shuffleId, sparkBlockId.mapId, sparkBlockId.reduceId)
        }
        if (!transport.ucxShuffleConf.pinMemory) {
          // Prefetching is disabled for now:
          // transport.prefetchBlocks(blockManagerId.executorId, blockIds)
        }
        blocks.map {
          case (blockId, length, _) =>
            if (length > accurateThreshold) {
              (blockId, (length * 1.2).toLong)
            } else {
              (blockId, accurateThreshold)
            }
        }
    }.toMap

  override def fetchBlocks(host: String, port: Int, execId: String,
                           blockIds: Array[String], listener: BlockFetchingListener,
                           downloadFileManager: DownloadFileManager): Unit = {
    val ucxBlockIds = new Array[BlockId](blockIds.length)
    val memoryBlocks = new Array[MemoryBlock](blockIds.length)
    val callbacks = new Array[OperationCallback](blockIds.length)
    for (i <- blockIds.indices) {
      val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[ShuffleBlockId]
      if (!blockSizes.contains(blockId)) {
        throw new UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}")
      }
      val resultMemory = transport.memoryPool.get(blockSizes(blockId))
      ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
      memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId))
      callbacks(i) = (result: OperationResult) => {
        if (result.getStatus == OperationStatus.SUCCESS) {
          val stats = result.getStats.get
          logInfo(s"Received block ${ucxBlockIds(i)} " +
            s"of size: ${stats.recvSize} " +
            s"in ${TimeUnit.NANOSECONDS.toMillis(stats.getElapsedTimeNs)} ms")
          val buffer = UcxUtils.getByteBufferView(resultMemory.address, stats.recvSize.toInt)
          listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
            // Return the pooled memory once Spark releases the buffer
            override def release: ManagedBuffer = {
              transport.memoryPool.put(resultMemory)
              this
            }
          })
        } else {
          logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" +
            s" ${result.getError.getMessage}")
          throw new UcxException(result.getError.getMessage)
        }
      }
    }
    transport.fetchBlocksByBlockIds(execId, ucxBlockIds, memoryBlocks, callbacks)
  }

  override def close(): Unit = {}
}
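
The buffer sizes in blockSizes deserve a note: Spark records exact block sizes only above spark.shuffle.accurateBlockThreshold (smaller blocks are reported as an average), so the client rounds small blocks up to the threshold and pads accurately-known sizes by 20%. Restated as a standalone rule (illustrative only, names are not part of the patch):

// Illustrative restatement of the receive-buffer sizing rule used above.
def allocationSize(reportedLength: Long, accurateThreshold: Long): Long =
  if (reportedLength > accurateThreshold) (reportedLength * 1.2).toLong // exact size plus safety margin
  else accurateThreshold // reported size may be an underestimate; allocate the threshold
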
22 changes: 17 additions & 5 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
@@ -12,15 +12,20 @@ import org.apache.spark.util.Utils
 class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
   private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name"
 
-  private val PROTOCOL =
+  object PROTOCOL extends Enumeration {
+    val ONE_SIDED, RNDV = Value
+  }
+
+  private lazy val PROTOCOL_CONF =
     ConfigBuilder(getUcxConf("protocol"))
-      .doc("Which protocol to use: rndv (default), one-sided")
+      .doc("Which protocol to use: RNDV (default), ONE-SIDED")
       .stringConf
-      .checkValue(protocol => protocol == "rndv" || protocol == "one-sided",
-        "Invalid protocol. Valid options: rndv / one-sided.")
-      .createWithDefault("rndv")
+      .transform(_.toUpperCase.replace("-", "_"))
+      .createWithDefault("RNDV")
 
-  private val MEMORY_PINNING =
+  private lazy val MEMORY_PINNING =
     ConfigBuilder(getUcxConf("memoryPinning"))
       .doc("Whether to pin whole shuffle data in memory")
       .booleanConf
@@ -67,7 +72,14 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
   lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key,
     MIN_REGISTRATION_SIZE.defaultValueString).toInt
 
-  lazy val protocol: String = conf.get(PROTOCOL.key, PROTOCOL.defaultValueString)
+  private lazy val USE_ODP =
+    ConfigBuilder(getUcxConf("useOdp"))
+      .doc("Whether to use on demand paging feature, to avoid memory pinning")
+      .booleanConf
+      .createWithDefault(false)
+
+  lazy val protocol: PROTOCOL.Value = PROTOCOL.withName(
+    conf.get(PROTOCOL_CONF.key, PROTOCOL_CONF.defaultValueString))
 
   lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), defaultValue = false)
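
The transform registered above normalizes user-facing protocol values when the entry's value is converted, so a setting like "one-sided" maps onto the ONE_SIDED enum value. A standalone illustration of that string mapping (not part of the patch):

// Illustrative: the normalization applied by .transform(...) before the
// PROTOCOL enum lookup.
def normalize(value: String): String = value.toUpperCase.replace("-", "_")
assert(normalize("one-sided") == "ONE_SIDED")
assert(normalize("rndv") == "RNDV")
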

102 changes: 102 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala
@@ -0,0 +1,102 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx

import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.Success

import org.apache.spark.rpc.RpcEnv
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors}
import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint}
import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
import org.apache.spark.util.RpcUtils
import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, TaskContext}


class UcxShuffleManager(conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) {

  val ucxShuffleConf = new UcxShuffleConf(conf)

  // The transport is only needed on executors; the driver merely coordinates
  // worker-address exchange over RPC.
  lazy val ucxShuffleTransport: UcxShuffleTransport = if (!isDriver) {
    new UcxShuffleTransport(ucxShuffleConf, "init")
  } else {
    null
  }

  @volatile private var initialized: Boolean = false

  override val shuffleBlockResolver =
    new UcxShuffleBlockResolver(ucxShuffleConf, ucxShuffleTransport)

  logInfo("Starting UcxShuffleManager")

  def initTransport(): Unit = this.synchronized {
    if (!initialized) {
      val driverEndpointName = "ucx-shuffle-driver"
      if (isDriver) {
        val rpcEnv = SparkEnv.get.rpcEnv
        val driverEndpoint = new UcxDriverRpcEndpoint(rpcEnv)
        rpcEnv.setupEndpoint(driverEndpointName, driverEndpoint)
      } else {
        val blockManager = SparkEnv.get.blockManager.blockManagerId
        ucxShuffleTransport.executorId = blockManager.executorId
        // host is passed twice: once as the bind address, once as the advertised address
        val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.host,
          blockManager.port, conf, new SecurityManager(conf), 1, clientMode = false)
        logDebug("Initializing ucx transport")
        val address = ucxShuffleTransport.init()
        val executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxShuffleTransport)
        val endpoint = rpcEnv.setupEndpoint(
          s"ucx-shuffle-executor-${blockManager.executorId}",
          executorEndpoint)

        // Announce this executor's UCX worker address to the driver and connect
        // to all executors the driver already knows about.
        val driverEndpoint = RpcUtils.makeDriverRef(driverEndpointName, conf, rpcEnv)
        driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId,
          endpoint, new SerializableDirectBuffer(address)))
          .andThen {
            case Success(msg) =>
              logInfo(s"Received reply $msg")
              executorEndpoint.receive(msg)
          }
      }
      initialized = true
    }
  }

  override def getReader[K, C](handle: ShuffleHandle,
                               startPartition: Int,
                               endPartition: Int,
                               context: TaskContext,
                               metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
      handle.shuffleId, startPartition, endPartition)
    new UcxShuffleReader(ucxShuffleTransport,
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
      shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
  }

  override def getReaderForRange[K, C](handle: ShuffleHandle,
                                       startMapIndex: Int,
                                       endMapIndex: Int,
                                       startPartition: Int,
                                       endPartition: Int,
                                       context: TaskContext,
                                       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange(
      handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
    new UcxShuffleReader(ucxShuffleTransport,
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
      shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
  }

  override def stop(): Unit = {
    if (ucxShuffleTransport != null) {
      ucxShuffleTransport.close()
    }
    super.stop()
  }
}
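
initTransport uses the usual volatile-flag-plus-synchronized idiom so the RPC endpoints and UCX transport are created at most once per process, however many threads race to initialize. Stripped to its skeleton (illustrative only, not part of the patch):

// Illustrative once-only initialization skeleton, as used by initTransport.
class OnceInit(body: () => Unit) {
  @volatile private var initialized = false
  def ensure(): Unit = this.synchronized {
    if (!initialized) {
      body() // runs exactly once
      initialized = true
    }
  }
}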