diff --git a/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl b/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl index ae51f7a4ab8b3..4e368573738f8 100644 --- a/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl +++ b/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl @@ -45,51 +45,6 @@ struct DrawIndexedIndirect { base_instance: u32, } -struct PartialDerivatives { - barycentrics: vec3, - ddx: vec3, - ddy: vec3, -} - -// https://github.com/ConfettiFX/The-Forge/blob/2d453f376ef278f66f97cbaf36c0d12e4361e275/Examples_3/Visibility_Buffer/src/Shaders/FSL/visibilityBuffer_shade.frag.fsl#L83-L139 -fn compute_derivatives(vertex_clip_positions: array, 3>, ndc_uv: vec2, screen_size: vec2) -> PartialDerivatives { - var result: PartialDerivatives; - - let inv_w = 1.0 / vec3(vertex_clip_positions[0].w, vertex_clip_positions[1].w, vertex_clip_positions[2].w); - let ndc_0 = vertex_clip_positions[0].xy * inv_w[0]; - let ndc_1 = vertex_clip_positions[1].xy * inv_w[1]; - let ndc_2 = vertex_clip_positions[2].xy * inv_w[2]; - - let inv_det = 1.0 / determinant(mat2x2(ndc_2 - ndc_1, ndc_0 - ndc_1)); - result.ddx = vec3(ndc_1.y - ndc_2.y, ndc_2.y - ndc_0.y, ndc_0.y - ndc_1.y) * inv_det * inv_w; - result.ddy = vec3(ndc_2.x - ndc_1.x, ndc_0.x - ndc_2.x, ndc_1.x - ndc_0.x) * inv_det * inv_w; - - var ddx_sum = dot(result.ddx, vec3(1.0)); - var ddy_sum = dot(result.ddy, vec3(1.0)); - - let delta_v = ndc_uv - ndc_0; - let interp_inv_w = inv_w.x + delta_v.x * ddx_sum + delta_v.y * ddy_sum; - let interp_w = 1.0 / interp_inv_w; - - result.barycentrics = vec3( - interp_w * (delta_v.x * result.ddx.x + delta_v.y * result.ddy.x + inv_w.x), - interp_w * (delta_v.x * result.ddx.y + delta_v.y * result.ddy.y), - interp_w * (delta_v.x * result.ddx.z + delta_v.y * result.ddy.z), - ); - - result.ddx *= 2.0 / screen_size.x; - result.ddy *= 2.0 / screen_size.y; - ddx_sum *= 2.0 / screen_size.x; - ddy_sum *= 2.0 / screen_size.y; - - let interp_ddx_w = 1.0 / (interp_inv_w + ddx_sum); - let interp_ddy_w = 1.0 / (interp_inv_w + ddy_sum); - - result.ddx = interp_ddx_w * (result.barycentrics * interp_inv_w + result.ddx) - result.barycentrics; - result.ddy = interp_ddy_w * (result.barycentrics * interp_inv_w + result.ddy) - result.barycentrics; - return result; -} - @group(#{MESHLET_BIND_GROUP}) @binding(0) var meshlets: array; @group(#{MESHLET_BIND_GROUP}) @binding(1) var meshlet_instance_uniforms: array; @group(#{MESHLET_BIND_GROUP}) @binding(2) var meshlet_thread_instance_ids: array; diff --git a/crates/bevy_pbr/src/meshlet/meshlet_mesh_material.wgsl b/crates/bevy_pbr/src/meshlet/meshlet_mesh_material.wgsl index 11d4bbb6748ae..9bf1d8ae6dbff 100644 --- a/crates/bevy_pbr/src/meshlet/meshlet_mesh_material.wgsl +++ b/crates/bevy_pbr/src/meshlet/meshlet_mesh_material.wgsl @@ -1,11 +1,7 @@ #import bevy_pbr::{ - meshlet_bindings::{meshlet_visibility_buffer, meshlet_thread_meshlet_ids, meshlets, meshlet_vertex_ids, meshlet_vertex_data, meshlet_thread_instance_ids, meshlet_instance_uniforms, unpack_meshlet_vertex, compute_derivatives}, - mesh_functions::mesh_position_local_to_world, - mesh_types::MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT, - view_transformations::{uv_to_ndc, position_world_to_clip}, + view_transformations::uv_to_ndc, + meshlet_visibility_buffer_utils::load_vertex_output, } -#import bevy_pbr::mesh_view_bindings::view -#import bevy_render::maths::{affine_to_square, mat2x4_f32_to_mat3x3_unpack} fn rand_f(state: ptr) -> f32 { *state = *state * 747796405u + 2891336453u; @@ -23,39 +19,10 @@ fn vertex(@builtin(vertex_index) vertex_input: u32) -> @builtin(position) vec4) -> @location(0) vec4 { - let vbuffer = textureLoad(meshlet_visibility_buffer, vec2(clip_position.xy), 0).r; - let thread_id = vbuffer >> 8u; - let meshlet_id = meshlet_thread_meshlet_ids[thread_id]; - let meshlet = meshlets[meshlet_id]; - let triangle_id = extractBits(vbuffer, 0u, 8u); +fn fragment(@builtin(position) frag_coord: vec4) -> @location(0) vec4 { + let vertex_output = load_vertex_output(frag_coord); - let indices = meshlet.start_vertex_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u); - let vertex_ids = vec3(meshlet_vertex_ids[indices.x], meshlet_vertex_ids[indices.y], meshlet_vertex_ids[indices.z]); - let vertex_1 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.x]); - let vertex_2 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.y]); - let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]); - - let instance_id = meshlet_thread_instance_ids[thread_id]; - let instance_uniform = meshlet_instance_uniforms[instance_id]; - let model = affine_to_square(instance_uniform.model); - - let world_position_1 = mesh_position_local_to_world(model, vec4(vertex_1.position, 1.0)); - let world_position_2 = mesh_position_local_to_world(model, vec4(vertex_2.position, 1.0)); - let world_position_3 = mesh_position_local_to_world(model, vec4(vertex_3.position, 1.0)); - let clip_position_1 = position_world_to_clip(world_position_1.xyz); - let clip_position_2 = position_world_to_clip(world_position_2.xyz); - let clip_position_3 = position_world_to_clip(world_position_3.xyz); - - let partial_derivatives = compute_derivatives( - array(clip_position_1, clip_position_2, clip_position_3), - clip_position.xy, - view.viewport.zw, - ); - - // TODO: Compute vertex output - - var rng = meshlet_id; + var rng = vertex_output.meshlet_id; let color = vec3(rand_f(&rng), rand_f(&rng), rand_f(&rng)); return vec4(color, 1.0); } diff --git a/crates/bevy_pbr/src/meshlet/mod.rs b/crates/bevy_pbr/src/meshlet/mod.rs index dce7989fa10c7..47853597ac0c9 100644 --- a/crates/bevy_pbr/src/meshlet/mod.rs +++ b/crates/bevy_pbr/src/meshlet/mod.rs @@ -48,8 +48,10 @@ use bevy_render::{ use bevy_transform::components::{GlobalTransform, Transform}; const MESHLET_BINDINGS_SHADER_HANDLE: Handle = Handle::weak_from_u128(1325134235233421); -pub(crate) const MESHLET_MESH_MATERIAL_SHADER_HANDLE: Handle = +const MESHLET_VISIBILITY_BUFFER_UTILS_SHADER_HANDLE: Handle = Handle::weak_from_u128(2325134235233421); +pub(crate) const MESHLET_MESH_MATERIAL_SHADER_HANDLE: Handle = + Handle::weak_from_u128(3325134235233421); pub struct MeshletPlugin; @@ -61,6 +63,12 @@ impl Plugin for MeshletPlugin { "meshlet_bindings.wgsl", Shader::from_wgsl ); + load_internal_asset!( + app, + MESHLET_VISIBILITY_BUFFER_UTILS_SHADER_HANDLE, + "visibility_buffer_utils.wgsl", + Shader::from_wgsl + ); load_internal_asset!( app, MESHLET_CULLING_SHADER_HANDLE, diff --git a/crates/bevy_pbr/src/meshlet/pipelines.rs b/crates/bevy_pbr/src/meshlet/pipelines.rs index 0c2ff8a8667b4..c46e633798497 100644 --- a/crates/bevy_pbr/src/meshlet/pipelines.rs +++ b/crates/bevy_pbr/src/meshlet/pipelines.rs @@ -9,11 +9,11 @@ use bevy_ecs::{ }; use bevy_render::render_resource::*; -pub const MESHLET_CULLING_SHADER_HANDLE: Handle = Handle::weak_from_u128(3325134235233421); +pub const MESHLET_CULLING_SHADER_HANDLE: Handle = Handle::weak_from_u128(4325134235233421); pub const MESHLET_VISIBILITY_BUFFER_SHADER_HANDLE: Handle = - Handle::weak_from_u128(4325134235233421); -pub const MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE: Handle = Handle::weak_from_u128(5325134235233421); +pub const MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE: Handle = + Handle::weak_from_u128(6325134235233421); #[derive(Resource)] pub struct MeshletPipelines { diff --git a/crates/bevy_pbr/src/meshlet/visibility_buffer_utils.wgsl b/crates/bevy_pbr/src/meshlet/visibility_buffer_utils.wgsl new file mode 100644 index 0000000000000..d1d4d67998522 --- /dev/null +++ b/crates/bevy_pbr/src/meshlet/visibility_buffer_utils.wgsl @@ -0,0 +1,93 @@ +#define_import_path bevy_pbr::meshlet_visibility_buffer_utils + +#import bevy_pbr::{ + meshlet_bindings::{meshlet_visibility_buffer, meshlet_thread_meshlet_ids, meshlets, meshlet_vertex_ids, meshlet_vertex_data, meshlet_thread_instance_ids, meshlet_instance_uniforms, unpack_meshlet_vertex}, + mesh_functions::mesh_position_local_to_world, + mesh_types::MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT, + view_transformations::{position_world_to_clip, frag_coord_to_ndc}, +} +#import bevy_pbr::mesh_view_bindings::view +#import bevy_render::maths::{affine_to_square, mat2x4_f32_to_mat3x3_unpack} + +struct PartialDerivatives { + barycentrics: vec3, + ddx: vec3, + ddy: vec3, +} + +// https://github.com/ConfettiFX/The-Forge/blob/2d453f376ef278f66f97cbaf36c0d12e4361e275/Examples_3/Visibility_Buffer/src/Shaders/FSL/visibilityBuffer_shade.frag.fsl#L83-L139 +fn compute_derivatives(vertex_clip_positions: array, 3>, ndc_uv: vec2, screen_size: vec2) -> PartialDerivatives { + var result: PartialDerivatives; + + let inv_w = 1.0 / vec3(vertex_clip_positions[0].w, vertex_clip_positions[1].w, vertex_clip_positions[2].w); + let ndc_0 = vertex_clip_positions[0].xy * inv_w[0]; + let ndc_1 = vertex_clip_positions[1].xy * inv_w[1]; + let ndc_2 = vertex_clip_positions[2].xy * inv_w[2]; + + let inv_det = 1.0 / determinant(mat2x2(ndc_2 - ndc_1, ndc_0 - ndc_1)); + result.ddx = vec3(ndc_1.y - ndc_2.y, ndc_2.y - ndc_0.y, ndc_0.y - ndc_1.y) * inv_det * inv_w; + result.ddy = vec3(ndc_2.x - ndc_1.x, ndc_0.x - ndc_2.x, ndc_1.x - ndc_0.x) * inv_det * inv_w; + + var ddx_sum = dot(result.ddx, vec3(1.0)); + var ddy_sum = dot(result.ddy, vec3(1.0)); + + let delta_v = ndc_uv - ndc_0; + let interp_inv_w = inv_w.x + delta_v.x * ddx_sum + delta_v.y * ddy_sum; + let interp_w = 1.0 / interp_inv_w; + + result.barycentrics = vec3( + interp_w * (delta_v.x * result.ddx.x + delta_v.y * result.ddy.x + inv_w.x), + interp_w * (delta_v.x * result.ddx.y + delta_v.y * result.ddy.y), + interp_w * (delta_v.x * result.ddx.z + delta_v.y * result.ddy.z), + ); + + result.ddx *= 2.0 / screen_size.x; + result.ddy *= 2.0 / screen_size.y; + ddx_sum *= 2.0 / screen_size.x; + ddy_sum *= 2.0 / screen_size.y; + + let interp_ddx_w = 1.0 / (interp_inv_w + ddx_sum); + let interp_ddy_w = 1.0 / (interp_inv_w + ddy_sum); + + result.ddx = interp_ddx_w * (result.barycentrics * interp_inv_w + result.ddx) - result.barycentrics; + result.ddy = interp_ddy_w * (result.barycentrics * interp_inv_w + result.ddy) - result.barycentrics; + return result; +} + +struct VertexOutput { + meshlet_id: u32, +} + +fn load_vertex_output(frag_coord: vec4) -> VertexOutput { + let vbuffer = textureLoad(meshlet_visibility_buffer, vec2(frag_coord.xy), 0).r; + let thread_id = vbuffer >> 8u; + let meshlet_id = meshlet_thread_meshlet_ids[thread_id]; + let meshlet = meshlets[meshlet_id]; + let triangle_id = extractBits(vbuffer, 0u, 8u); + + let indices = meshlet.start_vertex_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u); + let vertex_ids = vec3(meshlet_vertex_ids[indices.x], meshlet_vertex_ids[indices.y], meshlet_vertex_ids[indices.z]); + let vertex_1 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.x]); + let vertex_2 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.y]); + let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]); + + let instance_id = meshlet_thread_instance_ids[thread_id]; + let instance_uniform = meshlet_instance_uniforms[instance_id]; + let model = affine_to_square(instance_uniform.model); + + let world_position_1 = mesh_position_local_to_world(model, vec4(vertex_1.position, 1.0)); + let world_position_2 = mesh_position_local_to_world(model, vec4(vertex_2.position, 1.0)); + let world_position_3 = mesh_position_local_to_world(model, vec4(vertex_3.position, 1.0)); + let clip_position_1 = position_world_to_clip(world_position_1.xyz); + let clip_position_2 = position_world_to_clip(world_position_2.xyz); + let clip_position_3 = position_world_to_clip(world_position_3.xyz); + + let partial_derivatives = compute_derivatives( + array(clip_position_1, clip_position_2, clip_position_3), + frag_coord_to_ndc(frag_coord).xy, + view.viewport.zw, + ); + + // TODO: Compute vertex output + return VertexOutput(meshlet_id); +}