Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Spark Shuffle manager implementation. #26

Open
wants to merge 1 commit into
base: unified-api
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ See file LICENSE for terms.
</dependencies>

<build>
<finalName>${project.artifactId}-${project.version}-for-spark-${spark.version}</finalName>
<finalName>${project.artifactId}-${project.version}-for-spark-3.0</finalName>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx

import java.io.File
import java.nio.{ByteBuffer, MappedByteBuffer}
import java.nio.channels.FileChannel
import java.nio.file.StandardOpenOption
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

import org.openucx.jucx.UcxUtils
import org.apache.spark.shuffle.IndexShuffleBlockResolver
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.unsafe.Platform


case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {

def this(shuffleBlockId: ShuffleBlockId) = {
this(shuffleBlockId.shuffleId, shuffleBlockId.mapId, shuffleBlockId.reduceId)
}

def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}

case class BufferBackedBlock(buffer: ByteBuffer) extends Block {
override def getMemoryBlock: MemoryBlock = MemoryBlock(UcxUtils.getAddress(buffer), buffer.capacity())
}

class UcxShuffleBlockResolver(conf: UcxShuffleConf, transport: UcxShuffleTransport)
extends IndexShuffleBlockResolver(conf) {

type MapId = Long

private val numPartitionsForMapId = new ConcurrentHashMap[MapId, Int]

override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long,
lengths: Array[Long], dataTmp: File): Unit = {
super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
val dataFile = getDataFile(shuffleId, mapId)
if (!dataFile.exists()) {
return
}
numPartitionsForMapId.put(mapId, lengths.length)
val fileChannel = FileChannel.open(dataFile.toPath, StandardOpenOption.READ,
StandardOpenOption.WRITE)
val mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0L, dataFile.length())

val baseAddress = UcxUtils.getAddress(mappedBuffer)
fileChannel.close()

// Register whole map output file as dummy block
transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE),
BufferBackedBlock(mappedBuffer))

val offsetSize = 8 * (lengths.length + 1)
val indexBuf = Platform.allocateDirectBuffer(offsetSize)

var offset = 0L
indexBuf.putLong(offset)
for (reduceId <- lengths.indices) {
if (lengths(reduceId) > 0) {
transport.register(UcxShuffleBlockId(shuffleId, mapId, reduceId), new Block {
private val memoryBlock = MemoryBlock(baseAddress + offset, lengths(reduceId))
override def getMemoryBlock: MemoryBlock = memoryBlock
})
offset += lengths(reduceId)
indexBuf.putLong(offset)
}
}

if (transport.ucxShuffleConf.protocol == transport.ucxShuffleConf.PROTOCOL.ONE_SIDED) {
transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE), BufferBackedBlock(indexBuf))
}
}

override def removeDataByMap(shuffleId: ShuffleId, mapId: Long): Unit = {
transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE))
transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE))

val numRegisteredBlocks = numPartitionsForMapId.get(mapId)
(0 until numRegisteredBlocks)
.foreach(reduceId => transport.unregister(UcxShuffleBlockId(shuffleId, mapId, reduceId)))
super.removeDataByMap(shuffleId, mapId)
}

override def stop(): Unit = {
numPartitionsForMapId.keys.asScala.foreach(mapId => removeDataByMap(0, mapId))
super.stop()
}

}

object BlocksConstants {
val MAP_FILE: Int = -1
val INDEX_FILE: Int = -2
}
85 changes: 85 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx

import java.util.concurrent.TimeUnit

import org.openucx.jucx.{UcxException, UcxUtils}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager}
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, BlockId => SparkBlockId}

class UcxShuffleClient(transport: UcxShuffleTransport,
blocksByAddress: Iterator[(BlockManagerId, Seq[(SparkBlockId, Long, Int)])])
extends BlockStoreClient with Logging {

private val accurateThreshold = transport.ucxShuffleConf.conf.getSizeAsBytes(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key)

private val blockSizes: Map[SparkBlockId, Long] = blocksByAddress
.withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId }
.flatMap {
case (blockManagerId, blocks) =>
val blockIds = blocks.map {
case (blockId, _, _) =>
val sparkBlockId = blockId.asInstanceOf[ShuffleBlockId]
UcxShuffleBlockId(sparkBlockId.shuffleId, sparkBlockId.mapId, sparkBlockId.reduceId)
}
if (!transport.ucxShuffleConf.pinMemory) {
transport.prefetchBlocks(blockManagerId.executorId, blockIds)
}
blocks.map {
case (blockId, length, _) =>
if (length > accurateThreshold) {
(blockId, (length * 1.2).toLong)
} else {
(blockId, accurateThreshold)
}
}
}.toMap

override def fetchBlocks(host: String, port: Int, execId: String,
blockIds: Array[String], listener: BlockFetchingListener,
downloadFileManager: DownloadFileManager): Unit = {
val ucxBlockIds = new Array[BlockId](blockIds.length)
val memoryBlocks = new Array[MemoryBlock](blockIds.length)
val callbacks = new Array[OperationCallback](blockIds.length)
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[ShuffleBlockId]
if (!blockSizes.contains(blockId)) {
throw new UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}")
}
val resultMemory = transport.memoryPool.get(blockSizes(blockId))
ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId))
callbacks(i) = (result: OperationResult) => {
if (result.getStatus == OperationStatus.SUCCESS) {
val stats = result.getStats.get
logInfo(s" Received block ${ucxBlockIds(i)} " +
s"of size: ${stats.recvSize} " +
s"in ${TimeUnit.NANOSECONDS.toMillis(stats.getElapsedTimeNs)} ms")
val buffer = UcxUtils.getByteBufferView(resultMemory.address, result.getStats.get.recvSize.toInt)
listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
override def release: ManagedBuffer = {
transport.memoryPool.put(resultMemory)
this
}
})
} else {
logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" +
s" ${result.getError.getMessage}")
throw new UcxException(result.getError.getMessage)
}
}
}
transport.fetchBlocksByBlockIds(execId, ucxBlockIds, memoryBlocks, callbacks)
}

override def close(): Unit = {

}
}
34 changes: 27 additions & 7 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@ import org.apache.spark.util.Utils
class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name"

private val PROTOCOL =
object PROTOCOL extends Enumeration {
val ONE_SIDED, RNDV = Value
}

private lazy val PROTOCOL_CONF =
ConfigBuilder(getUcxConf("protocol"))
.doc("Which protocol to use: rndv (default), one-sided")
.doc("Which protocol to use: RNDV (default), ONE-SIDED")
.stringConf
.checkValue(protocol => protocol == "rndv" || protocol == "one-sided",
"Invalid protocol. Valid options: rndv / one-sided.")
.createWithDefault("rndv")
.transform(_.toUpperCase.replace("-", "_"))
.createWithDefault("RNDV")

private val MEMORY_PINNING =
private lazy val MEMORY_PINNING =
ConfigBuilder(getUcxConf("memoryPinning"))
.doc("Whether to pin whole shuffle data in memory")
.booleanConf
Expand All @@ -30,14 +35,14 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
ConfigBuilder(getUcxConf("maxWorkerSize"))
.doc("Maximum size of worker address in bytes")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(1000)
.createWithDefault(1024)

lazy val RPC_MESSAGE_SIZE: ConfigEntry[Long] =
ConfigBuilder(getUcxConf("rpcMessageSize"))
.doc("Size of RPC message to send from fetchBlockByBlockId. Must contain ")
.bytesConf(ByteUnit.BYTE)
.checkValue(size => size > maxWorkerAddressSize,
"Rpc message must contain workerAddress")
"Rpc message must contain at least workerAddress")
.createWithDefault(2000)

// Memory Pool
Expand All @@ -58,6 +63,12 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
.intConf
.createWithDefault(5)

private lazy val USE_SOCKADDR =
ConfigBuilder(getUcxConf("useSockAddr"))
.doc("Whether to use socket address to connect executors.")
.booleanConf
.createWithDefault(true)

private lazy val MIN_REGISTRATION_SIZE =
ConfigBuilder(getUcxConf("memory.minAllocationSize"))
.doc("Minimal memory registration size in memory pool.")
Expand All @@ -67,7 +78,14 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key,
MIN_REGISTRATION_SIZE.defaultValueString).toInt

lazy val protocol: String = conf.get(PROTOCOL.key, PROTOCOL.defaultValueString)
private lazy val USE_ODP =
ConfigBuilder(getUcxConf("useOdp"))
.doc("Whether to use on demand paging feature, to avoid memory pinning")
.booleanConf
.createWithDefault(false)

lazy val protocol: PROTOCOL.Value = PROTOCOL.withName(
conf.get(PROTOCOL_CONF.key, PROTOCOL_CONF.defaultValueString))

lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), defaultValue = false)

Expand All @@ -83,6 +101,8 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {

lazy val recvQueueSize: Int = conf.getInt(RECV_QUEUE_SIZE.key, RECV_QUEUE_SIZE.defaultValue.get)

lazy val useSockAddr: Boolean = conf.getBoolean(USE_SOCKADDR.key, USE_SOCKADDR.defaultValue.get)

lazy val preallocateBuffersMap: Map[Long, Int] = {
conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty)
.map(entry => entry.split(":") match {
Expand Down
102 changes: 102 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx

import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.Success

import org.apache.spark.rpc.RpcEnv
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors}
import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint}
import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
import org.apache.spark.util.RpcUtils
import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, TaskContext}


class UcxShuffleManager(conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) {

val ucxShuffleConf = new UcxShuffleConf(conf)

lazy val ucxShuffleTransport: UcxShuffleTransport = if (!isDriver) {
new UcxShuffleTransport(ucxShuffleConf, "init")
} else {
null
}

@volatile private var initialized: Boolean = false

override val shuffleBlockResolver =
new UcxShuffleBlockResolver(ucxShuffleConf, ucxShuffleTransport)

logInfo("Starting UcxShuffleManager")

def initTransport(): Unit = this.synchronized {
if (!initialized) {
val driverEndpointName = "ucx-shuffle-driver"
if (isDriver) {
val rpcEnv = SparkEnv.get.rpcEnv
val driverEndpoint = new UcxDriverRpcEndpoint(rpcEnv)
rpcEnv.setupEndpoint(driverEndpointName, driverEndpoint)
} else {
val blockManager = SparkEnv.get.blockManager.blockManagerId
ucxShuffleTransport.executorId = blockManager.executorId
val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.host,
blockManager.port, conf, new SecurityManager(conf), 1, clientMode=false)
logDebug("Initializing ucx transport")
val address = ucxShuffleTransport.init()
val executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxShuffleTransport)
val endpoint = rpcEnv.setupEndpoint(
s"ucx-shuffle-executor-${blockManager.executorId}",
executorEndpoint)

val driverEndpoint = RpcUtils.makeDriverRef(driverEndpointName, conf, rpcEnv)
driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId,
endpoint, new SerializableDirectBuffer(address)))
.andThen{
case Success(msg) =>
logInfo(s"Receive reply $msg")
executorEndpoint.receive(msg)
}
}
initialized = true
}
}

override def getReader[K, C](handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
handle.shuffleId, startPartition, endPartition)
new UcxShuffleReader(ucxShuffleTransport,
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
}

override def getReaderForRange[K, C]( handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange(
handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
new UcxShuffleReader(ucxShuffleTransport,
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
}

override def stop(): Unit = {
if (ucxShuffleTransport != null) {
ucxShuffleTransport.close()
}
super.stop()
}
}
Loading