From 55044195c8d6bdfdadcd7411aa6ebe9e97bcf045 Mon Sep 17 00:00:00 2001 From: Arman Bilge Date: Sun, 6 Aug 2023 00:29:53 +0000 Subject: [PATCH] Elide thunk allocation when using `sleepInternal` --- .../effect/unsafe/WorkStealingThreadPool.scala | 2 +- .../scala/cats/effect/unsafe/TimerSkipList.scala | 7 +++++-- .../effect/unsafe/WorkStealingThreadPool.scala | 6 ++++-- .../scala/cats/effect/unsafe/WorkerThread.scala | 4 +++- .../src/main/scala/cats/effect/IOFiber.scala | 16 ++++++++++------ 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index e8118e1cee..f47fc7889a 100644 --- a/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -30,7 +30,7 @@ private[effect] sealed abstract class WorkStealingThreadPool private () private[effect] def reschedule(runnable: Runnable): Unit private[effect] def sleepInternal( delay: FiniteDuration, - callback: Right[Nothing, Unit] => Unit): Runnable + callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable private[effect] def sleep( delay: FiniteDuration, task: Runnable, diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/TimerSkipList.scala b/core/jvm/src/main/scala/cats/effect/unsafe/TimerSkipList.scala index 5cd5cd884a..e02a2422b9 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/TimerSkipList.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/TimerSkipList.scala @@ -68,18 +68,21 @@ private final class TimerSkipList() extends AtomicLong(MARKER + 1L) { sequenceNu cb: Callback, next: Node ) extends TimerSkipListNodeBase[Callback, Node](cb, next) + with Function0[Unit] with Runnable { /** * Cancels the timer */ - final override def run(): Unit = { + final def apply(): Unit = { // TODO: We could null the callback here directly, // TODO: and the do the lookup after (for unlinking). TimerSkipList.this.doRemove(triggerTime, sequenceNum) () } + final def run() = apply() + private[TimerSkipList] final def isMarker: Boolean = { // note: a marker node also has `triggerTime == MARKER`, // but that's also a valid trigger time, so we need @@ -158,7 +161,7 @@ private final class TimerSkipList() extends AtomicLong(MARKER + 1L) { sequenceNu delay: Long, callback: Right[Nothing, Unit] => Unit, tlr: ThreadLocalRandom - ): Runnable = { + ): Function0[Unit] with Runnable = { require(delay >= 0L) // we have to check for overflow: val triggerTime = computeTriggerTime(now = now, delay = delay) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index 909746243d..cd740b0b41 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -621,7 +621,9 @@ private[effect] final class WorkStealingThreadPool( /** * Tries to call the current worker's `sleep`, but falls back to `sleepExternal` if needed. */ - def sleepInternal(delay: FiniteDuration, callback: Right[Nothing, Unit] => Unit): Runnable = { + def sleepInternal( + delay: FiniteDuration, + callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable = { val thread = Thread.currentThread() if (thread.isInstanceOf[WorkerThread]) { val worker = thread.asInstanceOf[WorkerThread] @@ -642,7 +644,7 @@ private[effect] final class WorkStealingThreadPool( */ private[this] final def sleepExternal( delay: FiniteDuration, - callback: Right[Nothing, Unit] => Unit): Runnable = { + callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable = { val random = ThreadLocalRandom.current() val idx = random.nextInt(threadCount) val tsl = sleepers(idx) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index 31b36e408e..849e71a2d4 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -152,7 +152,9 @@ private final class WorkerThread( } } - def sleep(delay: FiniteDuration, callback: Right[Nothing, Unit] => Unit): Runnable = { + def sleep( + delay: FiniteDuration, + callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable = { // take the opportunity to update the current time, just in case other timers can benefit val _now = System.nanoTime() now = _now diff --git a/core/shared/src/main/scala/cats/effect/IOFiber.scala b/core/shared/src/main/scala/cats/effect/IOFiber.scala index ef4957bbaf..f6214b50f8 100644 --- a/core/shared/src/main/scala/cats/effect/IOFiber.scala +++ b/core/shared/src/main/scala/cats/effect/IOFiber.scala @@ -925,13 +925,17 @@ private final class IOFiber[A]( IO { val scheduler = runtime.scheduler - val cancel = - if (scheduler.isInstanceOf[WorkStealingThreadPool]) - scheduler.asInstanceOf[WorkStealingThreadPool].sleepInternal(delay, cb) - else - scheduler.sleep(delay, () => cb(RightUnit)) + val cancelIO = + if (scheduler.isInstanceOf[WorkStealingThreadPool]) { + val cancel = + scheduler.asInstanceOf[WorkStealingThreadPool].sleepInternal(delay, cb) + IO.Delay(cancel, null) + } else { + val cancel = scheduler.sleep(delay, () => cb(RightUnit)) + IO(cancel.run()) + } - Some(IO(cancel.run())) + Some(cancelIO) } } else IO.cede