From bb1718001a0a66105e64ab8a0ca873d02f8f8a0e Mon Sep 17 00:00:00 2001 From: Zihao Zhao Date: Thu, 30 May 2024 07:53:28 +0800 Subject: [PATCH] add thread pool + progress thread --- .../shuffle/ucx/UcxShuffleTransport.scala | 110 ++++++----- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 181 ++++++++---------- .../ucx/rpc/GlobalWorkerRpcThread.scala | 20 +- 3 files changed, 156 insertions(+), 155 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 3f8a3df3..12cafc67 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -9,6 +9,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils} +import org.apache.spark.util.ThreadUtils import org.apache.spark.shuffle.utils.UnsafeUtils import org.openucx.jucx.UcxException import org.openucx.jucx.ucp._ @@ -16,6 +17,7 @@ import org.openucx.jucx.ucs.UcsConstants import java.net.InetSocketAddress import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.collection.mutable @@ -53,7 +55,7 @@ class UcxStats extends OperationStats { } case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { - override def serializedSize: Int = 12 + override def serializedSize: Int = UcxShuffleBockId.serializedSize override def serialize(byteBuffer: ByteBuffer): Unit = { byteBuffer.putInt(shuffleId) @@ -63,6 +65,8 @@ case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends B } object UcxShuffleBockId { + val serializedSize: Int = 12 + def deserialize(byteBuffer: ByteBuffer): UcxShuffleBockId = { val shuffleId = byteBuffer.getInt val mapId = byteBuffer.getInt @@ -81,13 +85,20 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val endpoints = mutable.Set.empty[UcpEndpoint] val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] - private var allocatedClientThreads: Array[UcxWorkerThread] = _ - private var allocatedServerThreads: Array[UcxWorkerThread] = _ + private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ + private var clientWorkerId = new AtomicInteger() + + private var allocatedServerWorkers: Array[UcxWorkerWrapper] = _ + private val serverWorkerId = new AtomicInteger() + private var serverLocal = new ThreadLocal[UcxWorkerWrapper] private val registeredBlocks = new TrieMap[BlockId, Block] private var progressThread: Thread = _ var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _ + private[spark] lazy val replyThreadPool = ThreadUtils.newForkJoinPool( + "UcxListenerThread", ucxShuffleConf.numListenerThreads) + private val errorHandler = new UcpEndpointErrorHandler { override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = { if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) { @@ -122,13 +133,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = ucxContext.newWorker(ucpWorkerParams) hostBounceBufferMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) - allocatedServerThreads = new Array[UcxWorkerThread](ucxShuffleConf.numListenerThreads) + allocatedServerWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numListenerThreads) logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers") for (i <- 0 until ucxShuffleConf.numListenerThreads) { val worker = ucxContext.newWorker(ucpWorkerParams) - val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong) - allocatedServerThreads(i) = new UcxWorkerThread(workerWrapper) - allocatedServerThreads(i).start() + allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false, i.toLong) } val Array(host, port) = ucxShuffleConf.listenerAddress.split(":") @@ -143,17 +152,17 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo progressThread = new GlobalWorkerRpcThread(globalWorker, this) progressThread.start() - allocatedClientThreads = new Array[UcxWorkerThread](ucxShuffleConf.numWorkers) + allocatedClientWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numWorkers) logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers") for (i <- 0 until ucxShuffleConf.numWorkers) { val clientId: Long = ((i.toLong + 1L) << 32) | executorId ucpWorkerParams.setClientId(clientId) val worker = ucxContext.newWorker(ucpWorkerParams) - val workerWrapper = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) - allocatedClientThreads(i) = new UcxWorkerThread(workerWrapper) - allocatedClientThreads(i).start() + allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) } + allocatedServerWorkers.foreach(_.progressStart()) + allocatedClientWorkers.foreach(_.progressStart()) initialized = true logInfo(s"Started listener on ${listener.getAddress}") SerializationUtils.serializeInetAddress(listener.getAddress) @@ -169,7 +178,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo hostBounceBufferMemoryPool.close() - allocatedClientThreads.foreach(_.close) + allocatedClientWorkers.foreach(_.close()) if (listener != null) { listener.close() @@ -186,7 +195,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = null } - allocatedServerThreads.foreach(_.close) + allocatedServerWorkers.foreach(_.close()) if (ucxContext != null) { ucxContext.close() @@ -200,11 +209,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * connection establishment outside of UcxShuffleManager. */ override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { - executorAddresses.put(executorId, workerAddress) - allocatedClientThreads.foreach(t => { - t.workerWrapper.getConnection(executorId) - t.workerWrapper.progressConnect() - }) + allocatedClientWorkers.foreach(_.getConnection(executorId)) } def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { @@ -214,7 +219,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } def preConnect(): Unit = { - allocatedClientThreads.foreach(_.workerWrapper.preconnect()) + allocatedClientWorkers.foreach(_.preconnect()) } /** @@ -273,51 +278,39 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, callbacks: Seq[OperationCallback]): Unit = { - val client = selectClientThread - client.submit(new Runnable { - override def run = client.workerWrapper.fetchBlocksByBlockIds( - executorId, blockIds, resultBufferAllocator, callbacks) - }) + selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds, + resultBufferAllocator, callbacks) } def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { - executorAddresses.put(executorId, workerAddress) - allocatedServerThreads.foreach(t => { - t.workerWrapper.connectByWorkerAddress(executorId, workerAddress) - }) + allocatedServerWorkers.foreach( + _.connectByWorkerAddress(executorId, workerAddress)) } - def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = { - val server = selectServerThread - server.submit(new Runnable { - override def run = { - val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) - val blockIds = mutable.ArrayBuffer.empty[BlockId] - - // 1. Deserialize blockIds from header - while (buffer.remaining() > 0) { - val blockId = UcxShuffleBockId.deserialize(buffer) - if (!registeredBlocks.contains(blockId)) { - throw new UcxException(s"$blockId is not registered") - } - blockIds += blockId - } - + def handleFetchBlockRequest(replyTag: Int, blockIds: Seq[BlockId], + replyExecutor: Long): Unit = { + replyThreadPool.submit(new Runnable { + override def run(): Unit = { val blocks = blockIds.map(bid => registeredBlocks(bid)) - amData.close() - - server.workerWrapper.handleFetchBlockRequest(blocks, replyTag, replyExecutor) + selectServerWorker.handleFetchBlockRequest(blocks, replyTag, + replyExecutor) } }) } @inline - def selectClientThread(): UcxWorkerThread = allocatedClientThreads( - (Thread.currentThread().getId % allocatedClientThreads.length).toInt) + def selectClientWorker(): UcxWorkerWrapper = allocatedClientWorkers( + (clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs) @inline - def selectServerThread(): UcxWorkerThread = allocatedServerThreads( - (Thread.currentThread().getId % allocatedServerThreads.length).toInt) + def selectServerWorker(): UcxWorkerWrapper = Option(serverLocal.get) match { + case Some(server) => server + case None => + val server = allocatedServerWorkers( + (serverWorkerId.incrementAndGet() % allocatedServerWorkers.length).abs) + serverLocal.set(server) + server + } /** * Progress outstanding operations. This routine is blocking (though may poll for event). @@ -329,3 +322,18 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo override def progress(): Unit = { } } + +private[ucx] class UcxSucceedOperationResult( + mem: MemoryBlock, stats: OperationStats) extends OperationResult { + override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS + + override def getError: TransportError = null + + override def getStats: Option[OperationStats] = Option(stats) + + override def getData: MemoryBlock = mem +} + +private[ucx] class UcxFetchState(val callbacks: Seq[OperationCallback], + val request: UcxRequest, + val timestamp: Long) {} \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 723cfe9e..f945c5e7 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -62,15 +62,19 @@ class UcxRefCountMemoryBlock(baseBlock: MemoryBlock, offset: Long, size: Long, case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, isClientWorker: Boolean, id: Long = 0L) extends Closeable with Logging { - - private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint] - private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)] + private[ucx] final val timeout = transport.ucxShuffleConf.getSparkConf.getTimeAsSeconds( + "spark.network.timeout", "120s") * 1000 + private[ucx] final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint] + private[ucx] lazy val requestData = new TrieMap[Int, UcxFetchState] private val tag = new AtomicInteger(Random.nextInt()) - private val flushRequests = new ConcurrentLinkedQueue[UcpRequest]() - private val ioThreadPool = ThreadUtils.newForkJoinPool("IO threads", + private[ucx] lazy val ioThreadOn = transport.ucxShuffleConf.numIoThreads > 1 + private[ucx] lazy val ioThreadPool = ThreadUtils.newForkJoinPool("IO threads", transport.ucxShuffleConf.numIoThreads) - private val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool) + private[ucx] lazy val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool) + private[ucx] var progressThread: Thread = _ + + private[ucx] lazy val memPool = transport.hostBounceBufferMemoryPool if (isClientWorker) { // Receive block data handler @@ -84,7 +88,9 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i throw new UcxException(s"No data for tag $i.") } - val (callbacks, request, allocator) = data.get + val fetchState = data.get + val callbacks = fetchState.callbacks + val request = fetchState.request val stats = request.getStats.get.asInstanceOf[UcxStats] stats.receiveSize = ucpAmData.getLength @@ -116,14 +122,14 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } if (callbacks.isEmpty) UcsConstants.STATUS.UCS_OK else UcsConstants.STATUS.UCS_INPROGRESS } else { - val mem = allocator(ucpAmData.getLength) + val mem = memPool.get(ucpAmData.getLength) stats.amHandleTime = System.nanoTime() request.setRequest(worker.recvAmDataNonBlocking(ucpAmData.getDataHandle, mem.address, ucpAmData.getLength, new UcxCallback() { override def onSuccess(r: UcpRequest): Unit = { request.completed = true stats.endTime = System.nanoTime() - logDebug(s"Received rndv data of size: ${mem.size} for tag $i in " + + logDebug(s"Received rndv data of size: ${ucpAmData.getLength} for tag $i in " + s"${stats.getElapsedTimeNs} ns " + s"time from amHandle: ${System.nanoTime() - stats.amHandleTime} ns") for (b <- 0 until numBlocks) { @@ -148,26 +154,35 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } override def close(): Unit = { - val closeRequests = connections.map { - case (_, endpoint) => endpoint.closeNonBlockingForce() - } - while (!closeRequests.forall(_.isCompleted)) { - progress() + if (isClientWorker) { + val closeRequests = connections.map { + case (_, endpoint) => endpoint.closeNonBlockingForce() + } + while (!closeRequests.forall(_.isCompleted)) { + progress() + } } - ioThreadPool.shutdown() connections.clear() + if (progressThread != null) { + progressThread.interrupt() + progressThread.join(1) + } + if (ioThreadOn) { + ioThreadPool.shutdown() + } worker.close() } + def progressStart(): Unit = { + progressThread = new ProgressThread(s"UCX-progress-$id", worker, + transport.ucxShuffleConf.useWakeup) + progressThread.start() + } + /** * Blocking progress until there's outstanding flush requests. */ def progressConnect(): Unit = { - while (!flushRequests.isEmpty) { - progress() - flushRequests.removeIf(_.isCompleted) - } - logTrace(s"Flush completed. Number of connections: ${connections.keys.size}") } /** @@ -196,15 +211,18 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i def getConnection(executorId: transport.ExecutorId): UcpEndpoint = { - val startTime = System.currentTimeMillis() - while (!transport.executorAddresses.contains(executorId)) { - if (System.currentTimeMillis() - startTime > - transport.ucxShuffleConf.getSparkConf.getTimeAsMs("spark.network.timeout", "100")) { - throw new UcxException(s"Don't get a worker address for $executorId") + if (!connections.contains(executorId)) { + if (!transport.executorAddresses.contains(executorId)) { + val startTime = System.currentTimeMillis() + while (!transport.executorAddresses.contains(executorId)) { + if (System.currentTimeMillis() - startTime > timeout) { + throw new UcxException(s"Don't get a worker address for $executorId") + } + } } } - connections.getOrElseUpdate(executorId, { + connections.getOrElseUpdate(executorId, { val address = transport.executorAddresses(executorId) val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() .setSocketAddress(SerializationUtils.deserializeInetAddress(address)).sendClientId() @@ -215,8 +233,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } }).setName(s"Endpoint to $executorId") - logDebug(s"Worker $this connecting to Executor($executorId, " + - s"${SerializationUtils.deserializeInetAddress(address)}") + logDebug(s"Worker ${id.toInt}:${id>>32} connecting to Executor($executorId)") worker.synchronized { val ep = worker.newEndpoint(endpointParams) val header = Platform.allocateDirectBuffer(UnsafeUtils.LONG_SIZE) @@ -231,8 +248,10 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i header.clear() workerAddress.clear() } + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send $errorMsg") + } }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - flushRequests.add(ep.flushNonBlocking(null)) ep } }) @@ -243,7 +262,6 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i callbacks: Seq[OperationCallback]): Unit = { val startTime = System.nanoTime() val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE - val ep = getConnection(executorId) val t = tag.incrementAndGet() @@ -253,29 +271,31 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i blockIds.foreach(b => b.serialize(buffer)) val request = new UcxRequest(null, new UcxStats()) - requestData.put(t, (callbacks, request, resultBufferAllocator)) + requestData.put(t, new UcxFetchState(callbacks, request, startTime)) buffer.rewind() val address = UnsafeUtils.getAdress(buffer) val dataAddress = address + headerSize - ep.sendAmNonBlocking(0, address, - headerSize, dataAddress, buffer.capacity() - headerSize, - UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - buffer.clear() - logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + - s"in ${System.nanoTime() - startTime} ns") - } - }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + val ep = getConnection(executorId) + worker.synchronized { + ep.sendAmNonBlocking(0, address, + headerSize, dataAddress, buffer.capacity() - headerSize, + UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + buffer.clear() + logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + + s"in ${System.nanoTime() - startTime} ns") + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + } } def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try { val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length - val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum) - .asInstanceOf[UcxBounceBufferMemoryBlock] - val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, - resultMemory.size) + val msgSize = tagAndSizes + blocks.map(_.getSize).sum + val resultMemory = memPool.get(msgSize).asInstanceOf[UcxBounceBufferMemoryBlock] + val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, msgSize) resultBuffer.putInt(replyTag) var offset = 0 @@ -289,7 +309,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i localBuffer } // Do parallel read of blocks - val blocksCollection = if (transport.ucxShuffleConf.numIoThreads > 1) { + val blocksCollection = if (ioThreadOn) { val parCollection = blocks.indices.par parCollection.tasksupport = ioTaskSupport parCollection @@ -302,73 +322,42 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } val startTime = System.nanoTime() - getConnection(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, - resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback { - override def onSuccess(request: UcpRequest): Unit = { - logTrace(s"Sent ${blocks.length} blocks of size: ${resultMemory.size} " + - s"to tag $replyTag in ${System.nanoTime() - startTime} ns.") - transport.hostBounceBufferMemoryPool.put(resultMemory) - } + val ep = getConnection(replyExecutor) + worker.synchronized { + ep.sendAmNonBlocking(1, resultMemory.address, tagAndSizes, + resultMemory.address + tagAndSizes, msgSize - tagAndSizes, 0, new UcxCallback { + override def onSuccess(request: UcpRequest): Unit = { + logTrace(s"Sent ${blocks.length} blocks of size: ${msgSize} " + + s"to tag $replyTag in ${System.nanoTime() - startTime} ns.") + resultMemory.close() + } - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to send $errorMsg") - } - }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - .setMemoryHandle(resultMemory.memory)) + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send $errorMsg") + resultMemory.close() + } + }, new UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + .setMemoryHandle(resultMemory.memory)) + } } catch { case ex: Throwable => logError(s"Failed to read and send data: $ex") } } -class UcxWorkerThread(val workerWrapper: UcxWorkerWrapper) extends Thread with Logging { - val id = workerWrapper.id - val worker = workerWrapper.worker - val transport = workerWrapper.transport - val useWakeup = workerWrapper.transport.ucxShuffleConf.useWakeup - - private val taskQueue = new ConcurrentLinkedQueue[Runnable]() - +private[ucx] class ProgressThread( + name: String, worker: UcpWorker, useWakeup: Boolean) extends Thread { setDaemon(true) - setName(s"UCX-worker $id") + setName(name) override def run(): Unit = { - logDebug(s"UCX-worker $id started") while (!isInterrupted) { - Option(taskQueue.poll()) match { - case Some(task) => task.run - case None => {} - } worker.synchronized { while (worker.progress() != 0) {} } - if(taskQueue.isEmpty && useWakeup) { + if (useWakeup) { worker.waitForEvents() } } - logDebug(s"UCX-worker $id stopped") - } - - @inline - def submit(task: Callable[_]): Future[_] = { - val future = new FutureTask(task) - taskQueue.offer(future) - worker.signal() - future - } - - @inline - def submit(task: Runnable): Future[Unit.type] = { - val future = new FutureTask(task, Unit) - taskQueue.offer(future) - worker.signal() - future - } - - @inline - def close(): Unit = { - interrupt() - join(10) - workerWrapper.close() } } \ No newline at end of file diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index d434542d..d8e42404 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -8,7 +8,7 @@ import java.nio.ByteBuffer import org.openucx.jucx.ucp.{UcpAmData, UcpConstants, UcpEndpoint, UcpWorker} import org.openucx.jucx.ucs.UcsConstants import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ucx.UcxShuffleTransport +import org.apache.spark.shuffle.ucx.{UcxShuffleTransport, UcxShuffleBockId} import org.apache.spark.shuffle.utils.UnsafeUtils import org.apache.spark.util.ThreadUtils @@ -22,9 +22,14 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) val replyTag = header.getInt val replyExecutor = header.getLong - transport.handleFetchBlockRequest(replyTag, amData, replyExecutor) - UcsConstants.STATUS.UCS_INPROGRESS - }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG ) + val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, + amData.getLength.toInt) + val blockNum = buffer.remaining() / UcxShuffleBockId.serializedSize + val blockIds = (0 until blockNum).map( + _ => UcxShuffleBockId.deserialize(buffer)) + transport.handleFetchBlockRequest(replyTag, blockIds, replyExecutor) + UcsConstants.STATUS.UCS_OK + }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG ) // AM to get worker address for client worker and connect server workers to it @@ -40,13 +45,12 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp override def run(): Unit = { if (transport.ucxShuffleConf.useWakeup) { while (!isInterrupted) { - if (globalWorker.progress() == 0) { - globalWorker.waitForEvents() - } + while (globalWorker.progress != 0) {} + globalWorker.waitForEvents() } } else { while (!isInterrupted) { - globalWorker.progress() + while (globalWorker.progress != 0) {} } } }