From 24febb753a0f65dc3ffc6048e1a1a4f0435a16c7 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sat, 14 Sep 2024 22:53:18 -0700 Subject: [PATCH] Multi-pass architecture for ReSTIR The goal of this approach is to aggressively re-use code, assuming that the driver will inline everything. Therefore, the whole pipeline is shaped as a loop over passes. --- blade-render/code/debug-param.inc.wgsl | 2 +- blade-render/code/ray-trace.wgsl | 261 +++++++++++++------------ blade-render/src/render/mod.rs | 11 +- 3 files changed, 138 insertions(+), 136 deletions(-) diff --git a/blade-render/code/debug-param.inc.wgsl b/blade-render/code/debug-param.inc.wgsl index 544c6af4..4904f1cd 100644 --- a/blade-render/code/debug-param.inc.wgsl +++ b/blade-render/code/debug-param.inc.wgsl @@ -4,8 +4,8 @@ struct DebugParams { view_mode: u32, + pass_index: u32, draw_flags: u32, texture_flags: u32, - pad: u32, mouse_pos: vec2, }; diff --git a/blade-render/code/ray-trace.wgsl b/blade-render/code/ray-trace.wgsl index 1d1c6dca..ace6f1d4 100644 --- a/blade-render/code/ray-trace.wgsl +++ b/blade-render/code/ray-trace.wgsl @@ -232,12 +232,15 @@ fn evaluate_brdf(surface: Surface, dir: vec3) -> f32 { return lambert_brdf * max(0.0, lambert_term); } -fn check_ray_occluded(acs: acceleration_structure, position: vec3, direction: vec3) -> bool { +fn check_ray_occluded(prev_frame: bool, position: vec3, direction: vec3) -> bool { var rq: ray_query; let flags = RAY_FLAG_TERMINATE_ON_FIRST_HIT | RAY_FLAG_CULL_NO_OPAQUE; - rayQueryInitialize(&rq, acs, - RayDesc(flags, 0xFFu, parameters.t_start, camera.depth, position, direction) - ); + let desc = RayDesc(flags, 0xFFu, parameters.t_start, camera.depth, position, direction); + if (prev_frame) { + rayQueryInitialize(&rq, prev_acc_struct, desc); + } else { + rayQueryInitialize(&rq, acc_struct, desc); + } rayQueryProceed(&rq); let intersection = rayQueryGetCommittedIntersection(&rq); @@ -273,7 +276,7 @@ fn make_target_score(color: vec3) -> TargetScore { } fn estimate_target_score_with_occlusion( - surface: Surface, position: vec3, light_index: u32, light_uv: vec2, acs: acceleration_structure, + surface: Surface, position: vec3, light_index: u32, light_uv: vec2, prev_frame: bool, ) -> TargetScore { if (light_index != 0u) { return TargetScore(); @@ -287,7 +290,7 @@ fn estimate_target_score_with_occlusion( return TargetScore(); } - if (check_ray_occluded(acs, position, direction)) { + if (check_ray_occluded(prev_frame, position, direction)) { return TargetScore(); } @@ -312,7 +315,7 @@ fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3) -> f return 0.0; } - if (check_ray_occluded(acc_struct, start_pos, dir)) { + if (check_ray_occluded(false, start_pos, dir)) { return 0.0; } @@ -402,62 +405,54 @@ struct ResampleBase { world_pos: vec3, accepted_count: f32, } -struct ResampleResult { - selected: bool, + +struct ShiftSample { + reservoir: LiveReservoir, mis_canonical: f32, mis_sample: f32, } // Resample following Algorithm 8 in section 9.1 of Bitterli thesis -fn resample( - dst: ptr, color_and_weight: ptr>, - base: ResampleBase, other: PixelCache, other_acs: acceleration_structure, +fn shift_sample( + base: ResampleBase, other: PixelCache, other_prev_frame: bool, max_confidence: f32, -) -> ResampleResult { - var src: LiveReservoir; +) -> ShiftSample { + var ss = ShiftSample(); let neighbor = other.reservoir; - var rr = ResampleResult(); if (parameters.use_pairwise_mis != 0u) { let canonical = base.canonical; let neighbor_history = min(neighbor.confidence, max_confidence); { // scoping this to hint the register allocation let t_canonical_at_neighbor = estimate_target_score_with_occlusion( - other.surface, other.world_pos, canonical.selected_light_index, canonical.selected_uv, other_acs); + other.surface, other.world_pos, canonical.selected_light_index, canonical.selected_uv, other_prev_frame); let nom = canonical.selected_target_score * canonical.history / base.accepted_count; let denom = t_canonical_at_neighbor.score * neighbor_history + nom; - rr.mis_canonical = select(0.0, nom / denom, denom > 0.0); + ss.mis_canonical = select(0.0, nom / denom, denom > 0.0); } + let canonical_prev_frame = false; let t_neighbor_at_canonical = estimate_target_score_with_occlusion( - base.surface, base.world_pos, neighbor.light_index, neighbor.light_uv, acc_struct); + base.surface, base.world_pos, neighbor.light_index, neighbor.light_uv, canonical_prev_frame); let nom = neighbor.target_score * neighbor_history; let denom = nom + t_neighbor_at_canonical.score * canonical.history / base.accepted_count; let mis_neighbor = select(0.0, nom / denom, denom > 0.0); - rr.mis_sample = mis_neighbor; + ss.mis_sample = mis_neighbor; + var src: LiveReservoir; src.history = neighbor_history; src.selected_light_index = neighbor.light_index; src.selected_uv = neighbor.light_uv; src.selected_target_score = t_neighbor_at_canonical.score; src.weight_sum = t_neighbor_at_canonical.score * neighbor.contribution_weight * mis_neighbor; src.radiance = t_neighbor_at_canonical.color; + ss.reservoir = src; } else { - rr.mis_canonical = 0.0; - rr.mis_sample = 1.0; + ss.mis_canonical = 0.5; + ss.mis_sample = 0.5; let radiance = evaluate_reflected_light(base.surface, neighbor.light_index, neighbor.light_uv); - src = unpack_reservoir(neighbor, max_confidence, radiance); - } - - if (DECOUPLED_SHADING) { - *color_and_weight += src.weight_sum * vec4(neighbor.contribution_weight * src.radiance, 1.0); - } - if (src.weight_sum <= 0.0) { - bump_reservoir(dst, src.history); - } else { - merge_reservoir(dst, src); - rr.selected = true; + ss.reservoir = unpack_reservoir(neighbor, max_confidence, radiance); } - return rr; + return ss; } struct ResampleOutput { @@ -503,115 +498,123 @@ fn finalize_resampling( return ro; } -fn resample_temporal( - surface: Surface, cur_pixel: vec2, position: vec3, - local_index: u32, tr: TemporalReprojection, -) -> ResampleOutput { - if (surface.depth == 0.0) { - return ResampleOutput(); - } - - let canonical = produce_canonical(surface, position); - if (parameters.temporal_tap == 0u || !tr.is_valid) { - return finalize_canonical(canonical); - } - - var reservoir = LiveReservoir(); - var color_and_weight = vec4(0.0); - let base = ResampleBase(surface, canonical, position, 1.0); - - let prev_dir = get_ray_direction(prev_camera, tr.pixel); - let prev_world_pos = prev_camera.position + tr.surface.depth * prev_dir; - let other = PixelCache(tr.surface, tr.reservoir, prev_world_pos); - let rr = resample(&reservoir, &color_and_weight, base, other, prev_acc_struct, parameters.temporal_tap_confidence); - let mis_canonical = 1.0 + rr.mis_canonical; - - if (WRITE_DEBUG_IMAGE && debug.view_mode == DebugMode_TemporalMatch) { - textureStore(out_debug, cur_pixel, vec4(1.0)); - } - if (WRITE_DEBUG_IMAGE && debug.view_mode == DebugMode_TemporalMisCanonical) { - let mis = mis_canonical / (1.0 + base.accepted_count); - textureStore(out_debug, cur_pixel, vec4(mis)); - } - - return finalize_resampling(&reservoir, &color_and_weight, base, mis_canonical); -} - -fn resample_spatial( - surface: Surface, cur_pixel: vec2, position: vec3, - group_id: vec3, canonical: LiveReservoir, -) -> ResampleOutput { - if (surface.depth == 0.0) { - let dir = normalize(position - camera.position); - var ro = ResampleOutput(); - ro.color = evaluate_environment(dir); - return ro; - } - - // gather the list of neighbors (within the workgroup) to resample. - var accepted_count = 0u; - var accepted_local_indices = array(); - let max_accepted = min(MAX_RESAMPLE, parameters.spatial_taps); - let num_candidates = parameters.spatial_taps * 4u; - for (var i = 0u; i < num_candidates && accepted_count < max_accepted; i += 1u) { - let other_cache_index = random_u32(&p_rng) % GROUP_SIZE_TOTAL; - let diff = thread_index_to_coord(other_cache_index, group_id) - cur_pixel; - if (dot(diff, diff) < parameters.spatial_min_distance * parameters.spatial_min_distance) { - continue; - } - let other = pixel_cache[other_cache_index]; - // if the surfaces are too different, there is no trust in this sample - if (other.reservoir.confidence > 0.0 && compare_surfaces(surface, other.surface) > 0.1) { - accepted_local_indices[accepted_count] = other_cache_index; - accepted_count += 1u; - } - } - - var reservoir = LiveReservoir(); - var color_and_weight = vec4(0.0); - let base = ResampleBase(surface, canonical, position, f32(accepted_count)); - var mis_canonical = 1.0; - - // evaluate the MIS of each of the samples versus the canonical one. - for (var lid = 0u; lid < accepted_count; lid += 1u) { - let other = pixel_cache[accepted_local_indices[lid]]; - let rr = resample(&reservoir, &color_and_weight, base, other, acc_struct, parameters.spatial_tap_confidence); - mis_canonical += rr.mis_canonical; - } - - if (WRITE_DEBUG_IMAGE && debug.view_mode == DebugMode_SpatialMatch) { - let value = base.accepted_count / max(1.0, f32(parameters.spatial_taps)); - textureStore(out_debug, cur_pixel, vec4(value)); - } - if (WRITE_DEBUG_IMAGE && debug.view_mode == DebugMode_SpatialMisCanonical) { - let mis = mis_canonical / (1.0 + base.accepted_count); - textureStore(out_debug, cur_pixel, vec4(mis)); - } - return finalize_resampling(&reservoir, &color_and_weight, base, mis_canonical); +struct Pass { + is_temporal: bool, + confidence: f32, + taps: u32, + candidates: u32, } fn compute_restir( rs: RichSurface, pixel: vec2, local_index: u32, group_id: vec3, ) -> vec3 { let center_coord = vec2(pixel) + 0.5 + select(vec2(0.0), rs.motion, parameters.use_motion_vectors != 0u); + //TODO: recompute this at the end? let tr = find_temporal(rs.inner, pixel, center_coord); - let motion_sqr = dot(rs.motion, rs.motion); - - let temporal = resample_temporal(rs.inner, pixel, rs.position, local_index, tr); - pixel_cache[local_index] = PixelCache(rs.inner, temporal.reservoir, rs.position); var prev_pixel = select(vec2(-1), tr.pixel, tr.is_valid); + let motion_sqr = dot(rs.motion, rs.motion); - // sync with the workgroup to ensure all reservoirs are available. - workgroupBarrier(); + var result = ResampleOutput(); + if (rs.inner.depth == 0.0) { + let dir = normalize(rs.position - camera.position); + result.color = evaluate_environment(dir); + } else { + let canonical = produce_canonical(rs.inner, rs.position); + result = finalize_canonical(canonical); + + var num_passes = 0u; + var passes = array(); + if (parameters.temporal_tap != 0u) { + passes[num_passes] = Pass(true, parameters.temporal_tap_confidence, 1, 0); + num_passes += 1u; + } + if (parameters.spatial_taps > 0) { + passes[num_passes] = Pass(false, parameters.spatial_tap_confidence, parameters.spatial_taps, parameters.spatial_taps * 4u); + num_passes += 1u; + } - let temporal_live = revive_canonical(temporal); - let spatial = resample_spatial(rs.inner, pixel, rs.position, group_id, temporal_live); + for(var pass_i = 0u; pass_i < num_passes; pass_i += 1u) { + let ps = passes[pass_i]; + var reservoir = LiveReservoir(); + var color_and_weight = vec4(0.0); + var mis_canonical = 0.0; + var accepted_count = 0u; + var accepted_local_indices = array(); + + if (ps.is_temporal) { + if (tr.is_valid) { + let prev_dir = get_ray_direction(prev_camera, tr.pixel); + let prev_world_pos = prev_camera.position + tr.surface.depth * prev_dir; + pixel_cache[local_index] = PixelCache(tr.surface, tr.reservoir, prev_world_pos); + accepted_local_indices[0] = local_index; + accepted_count += 1u; + } + } else { + pixel_cache[local_index] = PixelCache(rs.inner, result.reservoir, rs.position); + // sync with the workgroup to ensure all reservoirs are available. + workgroupBarrier(); + + // gather the list of neighbors (within the workgroup) to resample. + let max_accepted = min(MAX_RESAMPLE, ps.taps); + for (var i = 0u; i < ps.candidates && accepted_count < max_accepted; i += 1u) { + let other_cache_index = random_u32(&p_rng) % GROUP_SIZE_TOTAL; + let diff = thread_index_to_coord(other_cache_index, group_id) - pixel; + if (dot(diff, diff) < parameters.spatial_min_distance * parameters.spatial_min_distance) { + continue; + } + let other = pixel_cache[other_cache_index]; + // if the surfaces are too different, there is no trust in this sample + if (other.reservoir.confidence > 0.0 && compare_surfaces(rs.inner, other.surface) > 0.1) { + accepted_local_indices[accepted_count] = other_cache_index; + accepted_count += 1u; + } + } + } + + if (accepted_count == 0u) { + continue; + } + + let input = revive_canonical(result); + let base = ResampleBase(rs.inner, input, rs.position, f32(accepted_count)); + + mis_canonical = 1.0; + // evaluate the MIS of each of the samples versus the canonical one. + for (var lid = 0u; lid < accepted_count; lid += 1u) { + let other = pixel_cache[accepted_local_indices[lid]]; + + let ss = shift_sample(base, other, ps.is_temporal, ps.confidence); + mis_canonical += ss.mis_canonical; + + if (DECOUPLED_SHADING) { + let stored = pack_reservoir(ss.reservoir); + color_and_weight += ss.reservoir.weight_sum * vec4(stored.contribution_weight * ss.reservoir.radiance, 1.0); + } + if (ss.reservoir.weight_sum <= 0.0) { + bump_reservoir(&reservoir, ss.reservoir.history); + } else { + merge_reservoir(&reservoir, ss.reservoir); + } + } + + if (WRITE_DEBUG_IMAGE && pass_i == debug.pass_index) { + if (debug.view_mode == DebugMode_PassMatch) { + textureStore(out_debug, pixel, vec4(1.0)); + } + if (debug.view_mode == DebugMode_PassMisCanonical) { + let mis = mis_canonical / f32(1u + accepted_count); + textureStore(out_debug, pixel, vec4(mis)); + } + } + result = finalize_resampling(&reservoir, &color_and_weight, base, mis_canonical); + } + } let pixel_index = get_reservoir_index(pixel, camera); - reservoirs[pixel_index] = spatial.reservoir; + reservoirs[pixel_index] = result.reservoir; - accumulate_temporal(pixel, spatial.color, parameters.temporal_accumulation_weight, prev_pixel, motion_sqr); - return spatial.color; + accumulate_temporal(pixel, result.color, parameters.temporal_accumulation_weight, prev_pixel, motion_sqr); + return result.color; } @compute @workgroup_size(GROUP_SIZE.x, GROUP_SIZE.y) diff --git a/blade-render/src/render/mod.rs b/blade-render/src/render/mod.rs index 4b18eecd..5a5cb00f 100644 --- a/blade-render/src/render/mod.rs +++ b/blade-render/src/render/mod.rs @@ -53,10 +53,8 @@ pub enum DebugMode { HitConsistency = 4, Grouping = 5, Reprojection = 6, - TemporalMatch = 10, - TemporalMisCanonical = 11, - SpatialMatch = 12, - SpatialMisCanonical = 13, + PassMatch = 10, + PassMisCanonical = 11, Variance = 100, } @@ -86,6 +84,7 @@ bitflags::bitflags! { #[derive(Clone, Copy, Debug, Default)] pub struct DebugConfig { pub view_mode: DebugMode, + pub pass_index: u32, pub draw_flags: DebugDrawFlags, pub texture_flags: DebugTextureFlags, pub mouse_pos: Option<[i32; 2]>, @@ -355,9 +354,9 @@ struct CameraParams { #[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)] struct DebugParams { view_mode: u32, + pass_index: u32, draw_flags: u32, texture_flags: u32, - unused: u32, mouse_pos: [i32; 2], } @@ -972,9 +971,9 @@ impl Renderer { fn make_debug_params(&self, config: &DebugConfig) -> DebugParams { DebugParams { view_mode: config.view_mode as u32, + pass_index: config.pass_index, draw_flags: config.draw_flags.bits(), texture_flags: config.texture_flags.bits(), - unused: 0, mouse_pos: match config.mouse_pos { Some(p) => [p[0], self.surface_size.height as i32 - p[1]], None => [-1; 2],