diff --git a/tests/test_async.py b/tests/test_async.py index a58e1fd9..2bcd3ad0 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,3 +1,4 @@ + import anyio import pytest @@ -28,7 +29,7 @@ def poll_function(): awaitable = WgpuAwaitable("test", callback, finalizer, poll_function) if use_async: - result = await awaitable.async_wait() + result = await awaitable else: result = awaitable.sync_wait() assert result == 10 * 10 diff --git a/wgpu/backends/wgpu_native/_api.py b/wgpu/backends/wgpu_native/_api.py index 6afffabc..df79b290 100644 --- a/wgpu/backends/wgpu_native/_api.py +++ b/wgpu/backends/wgpu_native/_api.py @@ -357,7 +357,7 @@ async def request_adapter_async( force_fallback_adapter=force_fallback_adapter, canvas=canvas, ) # no-cover - return await awaitable.async_wait() + return await awaitable def _request_adapter( self, *, power_preference=None, force_fallback_adapter=False, canvas=None @@ -873,7 +873,7 @@ async def request_device_async( ) # Note that although we claim this function is async, the callback always # happens inside the call to libf.wgpuAdapterRequestDevice - return await awaitable.async_wait() + return await awaitable def _request_device( self, @@ -1602,7 +1602,7 @@ def finalizer(id): self._internal, descriptor, callback, ffi.NULL ) - return await awaitable.async_wait() + return await awaitable def _create_compute_pipeline_descriptor( self, @@ -1703,7 +1703,7 @@ def finalizer(id): self._internal, descriptor, callback, ffi.NULL ) - return await awaitable.async_wait() + return await awaitable def _create_render_pipeline_descriptor( self, @@ -2079,7 +2079,7 @@ async def map_async( self, mode: flags.MapMode, offset: int = 0, size: Optional[int] = None ): awaitable = self._map(mode, offset, size) # for now - return await awaitable.async_wait() + return await awaitable def _map(self, mode, offset=0, size=None): sync_on_read = True @@ -3539,7 +3539,7 @@ def on_submitted_work_done_sync(self): async def on_submitted_work_done_async(self): awaitable = self._on_submitted_word_done() - await awaitable.async_wait() + await awaitable def _on_submitted_word_done(self): @ffi.callback("void(WGPUQueueWorkDoneStatus, void*)") diff --git a/wgpu/backends/wgpu_native/_helpers.py b/wgpu/backends/wgpu_native/_helpers.py index c2df819c..4b2669b8 100644 --- a/wgpu/backends/wgpu_native/_helpers.py +++ b/wgpu/backends/wgpu_native/_helpers.py @@ -251,30 +251,11 @@ def set_result(self, result): def set_error(self, error): self.result = (None, error) - def sync_wait(self): - if not self.poll_function: - if self.result is None: - raise RuntimeError("Expected callback to have already happened") - else: - while not self._is_done(): - time.sleep(self.SLEEP_TIME) - return self.finish() - - async def async_wait(self): - if not self.poll_function: - if self.result is None: - raise RuntimeError("Expected callback to have already happened") - else: - while not self._is_done(): - # A bug in anyio prevents us from waiting on an Event() - await anyio.sleep(self.SLEEP_TIME) - return self.finish() - def _is_done(self): self.poll_function() return self.result is not None or time.perf_counter() > self.maxtime - def finish(self): + def _finish(self): if not self.result: raise RuntimeError(f"Waiting for {self.title} timed out.") result, error = self.result @@ -283,6 +264,33 @@ def finish(self): else: return self.finalizer(result) + def sync_wait(self): + if self.result is not None: + pass + elif not self.poll_function: + raise RuntimeError("Expected callback to have already happened") + else: + while not self._is_done(): + time.sleep(self.SLEEP_TIME) + return self._finish() + + def async_wait(self): + return self + + def __await__(self): + # There is no documentation on what __await__() is supposed to return, but we + # can certainly copy from a function that *does* know what to return + async def wait_for_callback(): + if self.result is not None: + return + if not self.poll_function: + raise RuntimeError("Expected callback to have already happened") + while not self._is_done(): + await anyio.sleep(self.SLEEP_TIME) + + yield from wait_for_callback().__await__() + return self._finish() + class ErrorHandler: """Object that logs errors, with the option to collect incoming