Skip to content

Commit

Permalink
select worker round robin
Browse files Browse the repository at this point in the history
  • Loading branch information
JeynmannZ committed Jun 3, 2024
1 parent 53ce91c commit 6610228
Showing 1 changed file with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
val executorAddresses = new TrieMap[ExecutorId, ByteBuffer]

private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _
private val clientWorkerId = new AtomicInteger()

private var allocatedServerWorkers: Array[UcxWorkerWrapper] = _
private val serverWorkerId = new AtomicInteger()
private val serverLocal = new ThreadLocal[UcxWorkerWrapper]

private val registeredBlocks = new TrieMap[BlockId, Block]
private var progressThread: Thread = _
Expand Down Expand Up @@ -296,11 +300,17 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo

@inline
def selectClientWorker(): UcxWorkerWrapper = allocatedClientWorkers(
(Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
(clientWorkerId.incrementAndGet() % allocatedClientWorkers.length).abs)

@inline
def selectServerWorker(): UcxWorkerWrapper = allocatedServerWorkers(
(Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
def selectServerWorker(): UcxWorkerWrapper = Option(serverLocal.get) match {
case Some(server) => server
case None =>
val server = allocatedServerWorkers(
(serverWorkerId.incrementAndGet() % allocatedServerWorkers.length).abs)
serverLocal.set(server)
server
}

/**
* Progress outstanding operations. This routine is blocking (though may poll for event).
Expand Down

0 comments on commit 6610228

Please sign in to comment.