Skip to content

Commit

Permalink
Fixes, move stuff to a seperate util file
Browse files Browse the repository at this point in the history
  • Loading branch information
JMS55 committed Nov 8, 2023
1 parent b39b225 commit 76d692f
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 87 deletions.
45 changes: 0 additions & 45 deletions crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,51 +45,6 @@ struct DrawIndexedIndirect {
base_instance: u32,
}

struct PartialDerivatives {
barycentrics: vec3<f32>,
ddx: vec3<f32>,
ddy: vec3<f32>,
}

// https://github.com/ConfettiFX/The-Forge/blob/2d453f376ef278f66f97cbaf36c0d12e4361e275/Examples_3/Visibility_Buffer/src/Shaders/FSL/visibilityBuffer_shade.frag.fsl#L83-L139
fn compute_derivatives(vertex_clip_positions: array<vec4<f32>, 3>, ndc_uv: vec2<f32>, screen_size: vec2<f32>) -> PartialDerivatives {
var result: PartialDerivatives;

let inv_w = 1.0 / vec3(vertex_clip_positions[0].w, vertex_clip_positions[1].w, vertex_clip_positions[2].w);
let ndc_0 = vertex_clip_positions[0].xy * inv_w[0];
let ndc_1 = vertex_clip_positions[1].xy * inv_w[1];
let ndc_2 = vertex_clip_positions[2].xy * inv_w[2];

let inv_det = 1.0 / determinant(mat2x2(ndc_2 - ndc_1, ndc_0 - ndc_1));
result.ddx = vec3(ndc_1.y - ndc_2.y, ndc_2.y - ndc_0.y, ndc_0.y - ndc_1.y) * inv_det * inv_w;
result.ddy = vec3(ndc_2.x - ndc_1.x, ndc_0.x - ndc_2.x, ndc_1.x - ndc_0.x) * inv_det * inv_w;

var ddx_sum = dot(result.ddx, vec3(1.0));
var ddy_sum = dot(result.ddy, vec3(1.0));

let delta_v = ndc_uv - ndc_0;
let interp_inv_w = inv_w.x + delta_v.x * ddx_sum + delta_v.y * ddy_sum;
let interp_w = 1.0 / interp_inv_w;

result.barycentrics = vec3(
interp_w * (delta_v.x * result.ddx.x + delta_v.y * result.ddy.x + inv_w.x),
interp_w * (delta_v.x * result.ddx.y + delta_v.y * result.ddy.y),
interp_w * (delta_v.x * result.ddx.z + delta_v.y * result.ddy.z),
);

result.ddx *= 2.0 / screen_size.x;
result.ddy *= 2.0 / screen_size.y;
ddx_sum *= 2.0 / screen_size.x;
ddy_sum *= 2.0 / screen_size.y;

let interp_ddx_w = 1.0 / (interp_inv_w + ddx_sum);
let interp_ddy_w = 1.0 / (interp_inv_w + ddy_sum);

result.ddx = interp_ddx_w * (result.barycentrics * interp_inv_w + result.ddx) - result.barycentrics;
result.ddy = interp_ddy_w * (result.barycentrics * interp_inv_w + result.ddy) - result.barycentrics;
return result;
}

@group(#{MESHLET_BIND_GROUP}) @binding(0) var<storage, read> meshlets: array<Meshlet>;
@group(#{MESHLET_BIND_GROUP}) @binding(1) var<storage, read> meshlet_instance_uniforms: array<Mesh>;
@group(#{MESHLET_BIND_GROUP}) @binding(2) var<storage, read> meshlet_thread_instance_ids: array<u32>;
Expand Down
43 changes: 5 additions & 38 deletions crates/bevy_pbr/src/meshlet/meshlet_mesh_material.wgsl
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
#import bevy_pbr::{
meshlet_bindings::{meshlet_visibility_buffer, meshlet_thread_meshlet_ids, meshlets, meshlet_vertex_ids, meshlet_vertex_data, meshlet_thread_instance_ids, meshlet_instance_uniforms, unpack_meshlet_vertex, compute_derivatives},
mesh_functions::mesh_position_local_to_world,
mesh_types::MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT,
view_transformations::{uv_to_ndc, position_world_to_clip},
view_transformations::uv_to_ndc,
meshlet_visibility_buffer_utils::load_vertex_output,
}
#import bevy_pbr::mesh_view_bindings::view
#import bevy_render::maths::{affine_to_square, mat2x4_f32_to_mat3x3_unpack}

fn rand_f(state: ptr<function, u32>) -> f32 {
*state = *state * 747796405u + 2891336453u;
Expand All @@ -23,39 +19,10 @@ fn vertex(@builtin(vertex_index) vertex_input: u32) -> @builtin(position) vec4<f
}

@fragment
fn fragment(@builtin(position) clip_position: vec4<f32>) -> @location(0) vec4<f32> {
let vbuffer = textureLoad(meshlet_visibility_buffer, vec2<i32>(clip_position.xy), 0).r;
let thread_id = vbuffer >> 8u;
let meshlet_id = meshlet_thread_meshlet_ids[thread_id];
let meshlet = meshlets[meshlet_id];
let triangle_id = extractBits(vbuffer, 0u, 8u);
fn fragment(@builtin(position) frag_coord: vec4<f32>) -> @location(0) vec4<f32> {
let vertex_output = load_vertex_output(frag_coord);

let indices = meshlet.start_vertex_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u);
let vertex_ids = vec3(meshlet_vertex_ids[indices.x], meshlet_vertex_ids[indices.y], meshlet_vertex_ids[indices.z]);
let vertex_1 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.x]);
let vertex_2 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.y]);
let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]);

let instance_id = meshlet_thread_instance_ids[thread_id];
let instance_uniform = meshlet_instance_uniforms[instance_id];
let model = affine_to_square(instance_uniform.model);

let world_position_1 = mesh_position_local_to_world(model, vec4(vertex_1.position, 1.0));
let world_position_2 = mesh_position_local_to_world(model, vec4(vertex_2.position, 1.0));
let world_position_3 = mesh_position_local_to_world(model, vec4(vertex_3.position, 1.0));
let clip_position_1 = position_world_to_clip(world_position_1.xyz);
let clip_position_2 = position_world_to_clip(world_position_2.xyz);
let clip_position_3 = position_world_to_clip(world_position_3.xyz);

let partial_derivatives = compute_derivatives(
array(clip_position_1, clip_position_2, clip_position_3),
clip_position.xy,
view.viewport.zw,
);

// TODO: Compute vertex output

var rng = meshlet_id;
var rng = vertex_output.meshlet_id;
let color = vec3(rand_f(&rng), rand_f(&rng), rand_f(&rng));
return vec4(color, 1.0);
}
10 changes: 9 additions & 1 deletion crates/bevy_pbr/src/meshlet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ use bevy_render::{
use bevy_transform::components::{GlobalTransform, Transform};

const MESHLET_BINDINGS_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(1325134235233421);
pub(crate) const MESHLET_MESH_MATERIAL_SHADER_HANDLE: Handle<Shader> =
const MESHLET_VISIBILITY_BUFFER_UTILS_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(2325134235233421);
pub(crate) const MESHLET_MESH_MATERIAL_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(3325134235233421);

pub struct MeshletPlugin;

Expand All @@ -61,6 +63,12 @@ impl Plugin for MeshletPlugin {
"meshlet_bindings.wgsl",
Shader::from_wgsl
);
load_internal_asset!(
app,
MESHLET_VISIBILITY_BUFFER_UTILS_SHADER_HANDLE,
"visibility_buffer_utils.wgsl",
Shader::from_wgsl
);
load_internal_asset!(
app,
MESHLET_CULLING_SHADER_HANDLE,
Expand Down
6 changes: 3 additions & 3 deletions crates/bevy_pbr/src/meshlet/pipelines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use bevy_ecs::{
};
use bevy_render::render_resource::*;

pub const MESHLET_CULLING_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(3325134235233421);
pub const MESHLET_CULLING_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(4325134235233421);
pub const MESHLET_VISIBILITY_BUFFER_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(4325134235233421);
pub const MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(5325134235233421);
pub const MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(6325134235233421);

#[derive(Resource)]
pub struct MeshletPipelines {
Expand Down
93 changes: 93 additions & 0 deletions crates/bevy_pbr/src/meshlet/visibility_buffer_utils.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#define_import_path bevy_pbr::meshlet_visibility_buffer_utils

#import bevy_pbr::{
meshlet_bindings::{meshlet_visibility_buffer, meshlet_thread_meshlet_ids, meshlets, meshlet_vertex_ids, meshlet_vertex_data, meshlet_thread_instance_ids, meshlet_instance_uniforms, unpack_meshlet_vertex},
mesh_functions::mesh_position_local_to_world,
mesh_types::MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT,
view_transformations::{position_world_to_clip, frag_coord_to_ndc},
}
#import bevy_pbr::mesh_view_bindings::view
#import bevy_render::maths::{affine_to_square, mat2x4_f32_to_mat3x3_unpack}

struct PartialDerivatives {
barycentrics: vec3<f32>,
ddx: vec3<f32>,
ddy: vec3<f32>,
}

// https://github.com/ConfettiFX/The-Forge/blob/2d453f376ef278f66f97cbaf36c0d12e4361e275/Examples_3/Visibility_Buffer/src/Shaders/FSL/visibilityBuffer_shade.frag.fsl#L83-L139
fn compute_derivatives(vertex_clip_positions: array<vec4<f32>, 3>, ndc_uv: vec2<f32>, screen_size: vec2<f32>) -> PartialDerivatives {
var result: PartialDerivatives;

let inv_w = 1.0 / vec3(vertex_clip_positions[0].w, vertex_clip_positions[1].w, vertex_clip_positions[2].w);
let ndc_0 = vertex_clip_positions[0].xy * inv_w[0];
let ndc_1 = vertex_clip_positions[1].xy * inv_w[1];
let ndc_2 = vertex_clip_positions[2].xy * inv_w[2];

let inv_det = 1.0 / determinant(mat2x2(ndc_2 - ndc_1, ndc_0 - ndc_1));
result.ddx = vec3(ndc_1.y - ndc_2.y, ndc_2.y - ndc_0.y, ndc_0.y - ndc_1.y) * inv_det * inv_w;
result.ddy = vec3(ndc_2.x - ndc_1.x, ndc_0.x - ndc_2.x, ndc_1.x - ndc_0.x) * inv_det * inv_w;

var ddx_sum = dot(result.ddx, vec3(1.0));
var ddy_sum = dot(result.ddy, vec3(1.0));

let delta_v = ndc_uv - ndc_0;
let interp_inv_w = inv_w.x + delta_v.x * ddx_sum + delta_v.y * ddy_sum;
let interp_w = 1.0 / interp_inv_w;

result.barycentrics = vec3(
interp_w * (delta_v.x * result.ddx.x + delta_v.y * result.ddy.x + inv_w.x),
interp_w * (delta_v.x * result.ddx.y + delta_v.y * result.ddy.y),
interp_w * (delta_v.x * result.ddx.z + delta_v.y * result.ddy.z),
);

result.ddx *= 2.0 / screen_size.x;
result.ddy *= 2.0 / screen_size.y;
ddx_sum *= 2.0 / screen_size.x;
ddy_sum *= 2.0 / screen_size.y;

let interp_ddx_w = 1.0 / (interp_inv_w + ddx_sum);
let interp_ddy_w = 1.0 / (interp_inv_w + ddy_sum);

result.ddx = interp_ddx_w * (result.barycentrics * interp_inv_w + result.ddx) - result.barycentrics;
result.ddy = interp_ddy_w * (result.barycentrics * interp_inv_w + result.ddy) - result.barycentrics;
return result;
}

struct VertexOutput {
meshlet_id: u32,
}

fn load_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let vbuffer = textureLoad(meshlet_visibility_buffer, vec2<i32>(frag_coord.xy), 0).r;
let thread_id = vbuffer >> 8u;
let meshlet_id = meshlet_thread_meshlet_ids[thread_id];
let meshlet = meshlets[meshlet_id];
let triangle_id = extractBits(vbuffer, 0u, 8u);

let indices = meshlet.start_vertex_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u);
let vertex_ids = vec3(meshlet_vertex_ids[indices.x], meshlet_vertex_ids[indices.y], meshlet_vertex_ids[indices.z]);
let vertex_1 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.x]);
let vertex_2 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.y]);
let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]);

let instance_id = meshlet_thread_instance_ids[thread_id];
let instance_uniform = meshlet_instance_uniforms[instance_id];
let model = affine_to_square(instance_uniform.model);

let world_position_1 = mesh_position_local_to_world(model, vec4(vertex_1.position, 1.0));
let world_position_2 = mesh_position_local_to_world(model, vec4(vertex_2.position, 1.0));
let world_position_3 = mesh_position_local_to_world(model, vec4(vertex_3.position, 1.0));
let clip_position_1 = position_world_to_clip(world_position_1.xyz);
let clip_position_2 = position_world_to_clip(world_position_2.xyz);
let clip_position_3 = position_world_to_clip(world_position_3.xyz);

let partial_derivatives = compute_derivatives(
array(clip_position_1, clip_position_2, clip_position_3),
frag_coord_to_ndc(frag_coord).xy,
view.viewport.zw,
);

// TODO: Compute vertex output
return VertexOutput(meshlet_id);
}

0 comments on commit 76d692f

Please sign in to comment.