From 66102287cf74fe3273d2558062c7baced2a554af Mon Sep 17 00:00:00 2001 From: zizhao Date: Mon, 17 Jul 2023 07:06:08 +0300 Subject: [PATCH] select worker round robin --- .../spark/shuffle/ucx/UcxShuffleTransport.scala | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index eccd281e..87c8c863 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -86,7 +86,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ + private val clientWorkerId = new AtomicInteger() + private var allocatedServerWorkers: Array[UcxWorkerWrapper] = _ + private val serverWorkerId = new AtomicInteger() + private val serverLocal = new ThreadLocal[UcxWorkerWrapper] private val registeredBlocks = new TrieMap[BlockId, Block] private var progressThread: Thread = _ @@ -296,11 +300,17 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo @inline def selectClientWorker(): UcxWorkerWrapper = allocatedClientWorkers( - (Thread.currentThread().getId % allocatedClientWorkers.length).toInt) + (clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs) @inline - def selectServerWorker(): UcxWorkerWrapper = allocatedServerWorkers( - (Thread.currentThread().getId % allocatedServerWorkers.length).toInt) + def selectServerWorker(): UcxWorkerWrapper = Option(serverLocal.get) match { + case Some(server) => server + case None => + val server = allocatedServerWorkers( + (serverWorkerId.incrementAndGet() % allocatedServerWorkers.length).abs) + serverLocal.set(server) + server + } /** * Progress outstanding operations. This routine is blocking (though may poll for event).