diff --git a/generator/templates/art/api_kotlin_async_helpers.kt b/generator/templates/art/api_kotlin_async_helpers.kt index afa3529334..3698661464 100644 --- a/generator/templates/art/api_kotlin_async_helpers.kt +++ b/generator/templates/art/api_kotlin_async_helpers.kt @@ -30,8 +30,8 @@ import kotlin.coroutines.resume import kotlin.coroutines.suspendCoroutine {% from 'art/api_kotlin_types.kt' import kotlin_declaration, kotlin_definition with context %} -//* We make a return class for every function pointer so that usage of callback-using methods can be -// replaced with suspend (async) function that returns the same data. +//* Legacy callback pattern: we make a return class for every function pointer so that usage of +//* callback-using methods can be replaced with suspend (async) function that returns the same data. {% for function_pointer in by_category['function pointer'] if len(function_pointer.name.chunks) > 1 %} //* Function pointers generally end in Callback which we replace with Return. @@ -42,8 +42,8 @@ import kotlin.coroutines.suspendCoroutine {% endfor %}) {% endfor %} -//* Every method that is identified as using callbacks is given a helper method that wraps the -//* call with a suspend function. +//* Legacy callback pattern: every method that is identified as using callbacks is given a helper +//* method that wraps the call with a suspend function. {% for obj in by_category['object'] %} {% for method in obj.methods if is_async_method(method) %} {% set function_pointer = method.arguments[-2].type %} @@ -67,3 +67,50 @@ import kotlin.coroutines.suspendCoroutine } {% endfor %} {% endfor %} + +//* Provide an async wrapper for the 'callback info' type of async methods. +{% for obj in by_category['object'] %} + {% for method in obj.methods if has_callbackInfoStruct(method) %} + {% set callback_info = method.arguments[-1].type %} + {% set callback_function = callback_info.members[-1].type %} + {% set return_name = callback_function.name.chunks[:-1] | map('title') | join + 'Return' %} + + //* We make a return class for every callback method so that it can be used inline + //* (without callbacks) in a suspend (async) function. + public data class {{ return_name }}( + {% for arg in kotlin_record_members(callback_function.arguments) %} + val {{ as_varName(arg.name) }}: {{ kotlin_declaration(arg) }}, + {% endfor %}) + + //* Every method that is identified as using callbacks is given a helper method that wraps + //* call with a suspend function. + public suspend fun {{ obj.name.CamelCase() }}.{{ method.name.camelCase() }}( + {%- for arg in method.arguments[:-1] %} + {{- as_varName(arg.name) }}: {{ kotlin_definition(arg) }}, + {%- endfor %}): {{ return_name }} = suspendCoroutine { + {{ method.name.camelCase() }}( + {%- for arg in method.arguments %} + {{- as_varName(arg.name) }} + {%- if loop.last %} + //* The final parameter of a callback method is always callback info. + //* We make this and include our generated callback. + {{- ' = ' }} + {{- callback_info.name.CamelCase() }}(CallbackMode.AllowSpontaneous) + {%- else %} + //* Non-final parameters are whatever the client supplied. + {{- ', ' }} + {%- endif %} + {%- endfor %}{ + {%- for arg in kotlin_record_members(callback_function.arguments) %} + {{- as_varName(arg.name) }}, + {%- endfor %} -> it.resume({{ return_name }}( + //* We make an instance of the callback parameters -> return type wrapper. + {%- for arg in kotlin_record_members(callback_function.arguments) %} + {{- as_varName(arg.name) }} {{ ', ' }} + {%- endfor %}) + ) + }) + } + {% endfor %} +{% endfor %} + diff --git a/tools/android/webgpu/src/androidTest/java/android/dawn/AsyncHelperTest.kt b/tools/android/webgpu/src/androidTest/java/android/dawn/AsyncHelperTest.kt new file mode 100644 index 0000000000..0be0988308 --- /dev/null +++ b/tools/android/webgpu/src/androidTest/java/android/dawn/AsyncHelperTest.kt @@ -0,0 +1,66 @@ +import android.dawn.* +import androidx.test.ext.junit.runners.AndroidJUnit4 +import org.junit.Test +import org.junit.runner.RunWith + +@RunWith(AndroidJUnit4::class) +class AsyncHelperTest { + @Test + fun asyncMethodTest() { + dawnTestLauncher() { device -> + /* Set up a shader module to support the async call. */ + val shaderModule = device.createShaderModule( + ShaderModuleDescriptor(shaderSourceWGSL = ShaderSourceWGSL("")) + ) + + /* Call an asynchronous method, converted from a callback pattern by a helper. */ + val result = device.createRenderPipelineAsync( + RenderPipelineDescriptor(vertex = VertexState(module = shaderModule)) + ) + + assert(result.status == CreatePipelineAsyncStatus.ValidationError) { + """Create render pipeline (async) should fail when no shader entry point exists. + The result was: ${result.status}""" + } + } + } + + @Test + fun asyncMethodTestValidationPasses() { + dawnTestLauncher() { device -> + /* Set up a shader module to support the async call. */ + val shaderModule = device.createShaderModule( + ShaderModuleDescriptor( + shaderSourceWGSL = ShaderSourceWGSL( + """ +@vertex fn vertexMain(@builtin(vertex_index) i : u32) -> +@builtin(position) vec4f { + return vec4f(); +} +@fragment fn fragmentMain() -> @location(0) vec4f { + return vec4f(); +} + """ + ) + ) + ) + + /* Call an asynchronous method, converted from a callback pattern by a helper. */ + val result = device.createRenderPipelineAsync( + RenderPipelineDescriptor( + vertex = VertexState(module = shaderModule), + fragment = FragmentState( + module = shaderModule, + targets = arrayOf(ColorTargetState(format = TextureFormat.RGBA8Unorm)) + ) + ) + ) + + assert(result.status == CreatePipelineAsyncStatus.Success) { + """Create render pipeline (async) should pass with a simple shader. + The result was: ${result.status} + The message was: ${result.message}""" + } + } + } +}