Skip to content

Commit

Permalink
add thread pool + progress thread
Browse files Browse the repository at this point in the history
  • Loading branch information
JeynmannZ committed May 29, 2024
1 parent c55a642 commit bb17180
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 155 deletions.
110 changes: 59 additions & 51 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool
import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread
import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.shuffle.utils.UnsafeUtils
import org.openucx.jucx.UcxException
import org.openucx.jucx.ucp._
import org.openucx.jucx.ucs.UcsConstants

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap
import scala.collection.mutable

Expand Down Expand Up @@ -53,7 +55,7 @@ class UcxStats extends OperationStats {
}

case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
override def serializedSize: Int = 12
override def serializedSize: Int = UcxShuffleBockId.serializedSize

override def serialize(byteBuffer: ByteBuffer): Unit = {
byteBuffer.putInt(shuffleId)
Expand All @@ -63,6 +65,8 @@ case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends B
}

object UcxShuffleBockId {
val serializedSize: Int = 12

def deserialize(byteBuffer: ByteBuffer): UcxShuffleBockId = {
val shuffleId = byteBuffer.getInt
val mapId = byteBuffer.getInt
Expand All @@ -81,13 +85,20 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
val endpoints = mutable.Set.empty[UcpEndpoint]
val executorAddresses = new TrieMap[ExecutorId, ByteBuffer]

private var allocatedClientThreads: Array[UcxWorkerThread] = _
private var allocatedServerThreads: Array[UcxWorkerThread] = _
private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _
private var clientWorkerId = new AtomicInteger()

private var allocatedServerWorkers: Array[UcxWorkerWrapper] = _
private val serverWorkerId = new AtomicInteger()
private var serverLocal = new ThreadLocal[UcxWorkerWrapper]

private val registeredBlocks = new TrieMap[BlockId, Block]
private var progressThread: Thread = _
var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _

private[spark] lazy val replyThreadPool = ThreadUtils.newForkJoinPool(
"UcxListenerThread", ucxShuffleConf.numListenerThreads)

private val errorHandler = new UcpEndpointErrorHandler {
override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) {
Expand Down Expand Up @@ -122,13 +133,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
globalWorker = ucxContext.newWorker(ucpWorkerParams)
hostBounceBufferMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext)

allocatedServerThreads = new Array[UcxWorkerThread](ucxShuffleConf.numListenerThreads)
allocatedServerWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numListenerThreads)
logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers")
for (i <- 0 until ucxShuffleConf.numListenerThreads) {
val worker = ucxContext.newWorker(ucpWorkerParams)
val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong)
allocatedServerThreads(i) = new UcxWorkerThread(workerWrapper)
allocatedServerThreads(i).start()
allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong)
}

val Array(host, port) = ucxShuffleConf.listenerAddress.split(":")
Expand All @@ -143,17 +152,17 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
progressThread = new GlobalWorkerRpcThread(globalWorker, this)
progressThread.start()

allocatedClientThreads = new Array[UcxWorkerThread](ucxShuffleConf.numWorkers)
allocatedClientWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numWorkers)
logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers")
for (i <- 0 until ucxShuffleConf.numWorkers) {
val clientId: Long = ((i.toLong + 1L) << 32) | executorId
ucpWorkerParams.setClientId(clientId)
val worker = ucxContext.newWorker(ucpWorkerParams)
val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId)
allocatedClientThreads(i) = new UcxWorkerThread(workerWrapper)
allocatedClientThreads(i).start()
allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId)
}

allocatedServerWorkers.foreach(_.progressStart())
allocatedClientWorkers.foreach(_.progressStart())
initialized = true
logInfo(s"Started listener on ${listener.getAddress}")
SerializationUtils.serializeInetAddress(listener.getAddress)
Expand All @@ -169,7 +178,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo

hostBounceBufferMemoryPool.close()

allocatedClientThreads.foreach(_.close)
allocatedClientWorkers.foreach(_.close())

if (listener != null) {
listener.close()
Expand All @@ -186,7 +195,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
globalWorker = null
}

allocatedServerThreads.foreach(_.close)
allocatedServerWorkers.foreach(_.close())

if (ucxContext != null) {
ucxContext.close()
Expand All @@ -200,11 +209,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
* connection establishment outside of UcxShuffleManager.
*/
override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
executorAddresses.put(executorId, workerAddress)
allocatedClientThreads.foreach(t => {
t.workerWrapper.getConnection(executorId)
t.workerWrapper.progressConnect()
})
allocatedClientWorkers.foreach(_.getConnection(executorId))
}

def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = {
Expand All @@ -214,7 +219,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
}

def preConnect(): Unit = {
allocatedClientThreads.foreach(_.workerWrapper.preconnect())
allocatedClientWorkers.foreach(_.preconnect())
}

/**
Expand Down Expand Up @@ -273,51 +278,39 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
resultBufferAllocator: BufferAllocator,
callbacks: Seq[OperationCallback]): Unit = {
val client = selectClientThread
client.submit(new Runnable {
override def run = client.workerWrapper.fetchBlocksByBlockIds(
executorId, blockIds, resultBufferAllocator, callbacks)
})
selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds,
resultBufferAllocator, callbacks)
}

def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
executorAddresses.put(executorId, workerAddress)
allocatedServerThreads.foreach(t => {
t.workerWrapper.connectByWorkerAddress(executorId, workerAddress)
})
allocatedServerWorkers.foreach(
_.connectByWorkerAddress(executorId, workerAddress))
}

def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = {
val server = selectServerThread
server.submit(new Runnable {
override def run = {
val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt)
val blockIds = mutable.ArrayBuffer.empty[BlockId]

// 1. Deserialize blockIds from header
while (buffer.remaining() > 0) {
val blockId = UcxShuffleBockId.deserialize(buffer)
if (!registeredBlocks.contains(blockId)) {
throw new UcxException(s"$blockId is not registered")
}
blockIds += blockId
}

def handleFetchBlockRequest(replyTag: Int, blockIds: Seq[BlockId],
replyExecutor: Long): Unit = {
replyThreadPool.submit(new Runnable {
override def run(): Unit = {
val blocks = blockIds.map(bid => registeredBlocks(bid))
amData.close()

server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor)
selectServerWorker.handleFetchBlockRequest(blocks, replyTag,
replyExecutor)
}
})
}

@inline
def selectClientThread(): UcxWorkerThread = allocatedClientThreads(
(Thread.currentThread().getId % allocatedClientThreads.length).toInt)
def selectClientWorker(): UcxWorkerWrapper = allocatedClientWorkers(
(clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs)

@inline
def selectServerThread(): UcxWorkerThread = allocatedServerThreads(
(Thread.currentThread().getId % allocatedServerThreads.length).toInt)
def selectServerWorker(): UcxWorkerWrapper = Option(serverLocal.get) match {
case Some(server) => server
case None =>
val server = allocatedServerWorkers(
(serverWorkerId.incrementAndGet() % allocatedServerWorkers.length).abs)
serverLocal.set(server)
server
}

/**
* Progress outstanding operations. This routine is blocking (though may poll for event).
Expand All @@ -329,3 +322,18 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
override def progress(): Unit = {
}
}

private[ucx] class UcxSucceedOperationResult(
mem: MemoryBlock, stats: OperationStats) extends OperationResult {
override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS

override def getError: TransportError = null

override def getStats: Option[OperationStats] = Option(stats)

override def getData: MemoryBlock = mem
}

private[ucx] class UcxFetchState(val callbacks: Seq[OperationCallback],
val request: UcxRequest,
val timestamp: Long) {}
Loading

0 comments on commit bb17180

Please sign in to comment.