Skip to content

Commit

Permalink
Fix groups, canonical sampling, refactor MIS
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Sep 3, 2024
1 parent c39a7b1 commit ddee75e
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 78 deletions.
17 changes: 11 additions & 6 deletions blade-graphics/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,19 @@ impl super::TextureFormat {
}
}

impl super::Extent {
pub fn group_by(&self, size: [u32; 3]) -> [u32; 3] {
[
(self.width + size[0] - 1) / size[0],
(self.height + size[1] - 1) / size[1],
(self.depth + size[2] - 1) / size[2],
]
}
}

impl super::ComputePipeline {
/// Return the dispatch group counts sufficient to cover the given extent.
pub fn get_dispatch_for(&self, extent: super::Extent) -> [u32; 3] {
let wg_size = self.get_workgroup_size();
[
(extent.width + wg_size[0] - 1) / wg_size[0],
(extent.height + wg_size[1] - 1) / wg_size[1],
(extent.depth + wg_size[2] - 1) / wg_size[2],
]
extent.group_by(self.get_workgroup_size())
}
}
180 changes: 111 additions & 69 deletions blade-render/code/ray-trace.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ const PI: f32 = 3.1415926;
const MAX_RESAMPLE: u32 = 4u;
// See "9.1 pairwise mis for robust reservoir reuse"
// "Correlations and Reuse for Fast and Accurate Physically Based Light Transport"
const PAIRWISE_MIS: bool = true;
const DEFENSIVE_MIS: bool = false;
const PAIRWISE_MIS: bool = false;
// See "DECOUPLING SHADING AND REUSE" in
// "Rearchitecting Spatiotemporal Resampling for Production"
const DECOUPLED_SHADING: bool = false;
Expand Down Expand Up @@ -121,12 +120,12 @@ fn merge_reservoir(r: ptr<function, LiveReservoir>, other: LiveReservoir, random
return false;
}
}
fn unpack_reservoir(f: StoredReservoir, max_history: u32) -> LiveReservoir {
fn unpack_reservoir(f: StoredReservoir, max_history: u32, radiance: vec3<f32>) -> LiveReservoir {
var r: LiveReservoir;
r.selected_light_index = f.light_index;
r.selected_uv = f.light_uv;
r.selected_target_score = f.target_score;
r.radiance = vec3<f32>(0.0); // to be continued...
r.radiance = radiance;
let history = min(f.confidence, f32(max_history));
r.weight_sum = f.contribution_weight * f.target_score * history;
r.history = history;
Expand Down Expand Up @@ -305,11 +304,11 @@ fn estimate_target_score_with_occlusion(

if (check_ray_occluded(acs, position, direction, debug_len)) {
return TargetScore();
} else {
//Note: same as `evaluate_reflected_light`
let radiance = textureSampleLevel(env_map, sampler_nearest, light_uv, 0.0).xyz;
return make_target_score(brdf * radiance);
}

//Note: same as `evaluate_reflected_light`
let radiance = textureSampleLevel(env_map, sampler_nearest, light_uv, 0.0).xyz;
return make_target_score(brdf * radiance);
}

fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3<f32>, debug_len: f32) -> f32 {
Expand All @@ -335,6 +334,30 @@ fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3<f32>, debu
return brdf;
}

fn produce_canonical(
surface: Surface, position: vec3<f32>,
rng: ptr<function, RandomState>, debug_len: f32,
) -> LiveReservoir {
var reservoir = LiveReservoir();
for (var i = 0u; i < parameters.num_environment_samples; i += 1u) {
var ls: LightSample;
if (parameters.environment_importance_sampling != 0u) {
ls = sample_light_from_environment(rng);
} else {
ls = sample_light_from_sphere(rng);
}

let brdf = evaluate_sample(ls, surface, position, debug_len);
if (brdf > 0.0) {
let other = make_reservoir(ls, 0u, vec3<f32>(brdf));
merge_reservoir(&reservoir, other, random_gen(rng));
} else {
bump_reservoir(&reservoir, 1.0);
}
}
return reservoir;
}

struct ResampleBase {
surface: Surface,
canonical: LiveReservoir,
Expand All @@ -344,29 +367,27 @@ struct ResampleBase {
struct ResampleResult {
selected: bool,
mis_canonical: f32,
mis_sample: f32,
}

const canonical_count: f32 = 1.0;

// Resample following Algorithm 8 in section 9.1 of Bitterli thesis
fn resample(
dst: ptr<function, LiveReservoir>, color_and_weight: ptr<function, vec4<f32>>,
base: ResampleBase, other: PixelCache, other_acs: acceleration_structure, max_history: u32,
rng: ptr<function, RandomState>, enable_debug: bool,
rng: ptr<function, RandomState>, debug_len: f32,
) -> ResampleResult {
var src: LiveReservoir;
let neighbor = other.reservoir;
var rr = ResampleResult();
if (PAIRWISE_MIS) {
let debug_len = select(0.0, other.surface.depth * 0.2, enable_debug);
let canonical = base.canonical;
let neighbor_history = min(neighbor.confidence, f32(max_history));
{ // scoping this to hint the register allocation
let t_canonical_at_neighbor = estimate_target_score_with_occlusion(
other.surface, other.world_pos, canonical.selected_light_index, canonical.selected_uv, other_acs, debug_len);
let nom = canonical.selected_target_score * canonical.history;
let denom = canonical_count * nom + t_canonical_at_neighbor.score * neighbor_history * base.accepted_count;
let kf = 1.0 / select(base.accepted_count, canonical_count + base.accepted_count, DEFENSIVE_MIS);
rr.mis_canonical = kf * nom / max(0.01, denom);
let nom = canonical.selected_target_score * canonical.history / base.accepted_count;
let denom = t_canonical_at_neighbor.score * neighbor_history + nom;
rr.mis_canonical = select(0.0, nom / denom, denom > 0.0);
}

// Notes about t_neighbor_at_neighbor:
Expand All @@ -376,10 +397,10 @@ fn resample(
//let t_neighbor_at_neighbor = estimate_target_pdf(neighbor_surface, neighbor_position, neighbor.selected_dir);
let t_neighbor_at_canonical = estimate_target_score_with_occlusion(
base.surface, base.world_pos, neighbor.light_index, neighbor.light_uv, acc_struct, debug_len);
let nom = t_neighbor_at_canonical.score * canonical.history;
let denom = canonical_count * neighbor.target_score * neighbor_history + base.accepted_count * nom;
let kf = select(1.0, base.accepted_count / (canonical_count + base.accepted_count), DEFENSIVE_MIS);
let mis_neighbor = kf * nom / max(0.01, denom);
let nom = neighbor.target_score * neighbor_history;
let denom = nom + t_neighbor_at_canonical.score * canonical.history / base.accepted_count;
let mis_neighbor = select(0.0, nom / denom, denom > 0.0);
rr.mis_sample = mis_neighbor;

src.history = neighbor_history;
src.selected_light_index = neighbor.light_index;
Expand All @@ -388,11 +409,14 @@ fn resample(
src.weight_sum = t_neighbor_at_canonical.score * neighbor.contribution_weight * mis_neighbor;
src.radiance = t_neighbor_at_canonical.color;
} else {
src = unpack_reservoir(neighbor, max_history);
src.radiance = evaluate_reflected_light(base.surface, src.selected_light_index, src.selected_uv);
rr.mis_canonical = 0.0;
rr.mis_sample = 1.0;
let radiance = evaluate_reflected_light(base.surface, neighbor.light_index, neighbor.light_uv);
src = unpack_reservoir(neighbor, max_history, radiance);
}

if (DECOUPLED_SHADING) {
//TODO: use `mis_neighbor`O
*color_and_weight += src.weight_sum * vec4<f32>(neighbor.contribution_weight * src.radiance, 1.0);
}
if (src.weight_sum <= 0.0) {
Expand All @@ -409,29 +433,39 @@ struct ResampleOutput {
color: vec3<f32>,
}

fn revive_canonical(ro: ResampleOutput) -> LiveReservoir {
let radiance = select(vec3<f32>(0.0), ro.color / ro.reservoir.contribution_weight, ro.reservoir.contribution_weight > 0.0);
return unpack_reservoir(ro.reservoir, ~0u, radiance);
}

fn finalize_canonical(reservoir: LiveReservoir) -> ResampleOutput {
var ro = ResampleOutput();
ro.reservoir = pack_reservoir(reservoir);
ro.color = ro.reservoir.contribution_weight * reservoir.radiance;
return ro;
}

fn finalize_resampling(
reservoir: ptr<function, LiveReservoir>, color_and_weight: ptr<function, vec4<f32>>,
base: ResampleBase, mis_canonical: f32, rng: ptr<function, RandomState>,
) -> ResampleOutput {
var ro = ResampleOutput();
var canonical = base.canonical;
if (PAIRWISE_MIS && canonical.history > 0.0) {
//TODO: fix the case of `mis_canonical` being too low
canonical.weight_sum *= mis_canonical / canonical.history;
}
canonical.weight_sum *= mis_canonical;
merge_reservoir(reservoir, canonical, random_gen(rng));

if (base.accepted_count > 0.0) {
let effective_history = select((*reservoir).history, base.accepted_count, PAIRWISE_MIS);
let effective_history = select((*reservoir).history, 1.0 + base.accepted_count, PAIRWISE_MIS);
ro.reservoir = pack_reservoir_detail(*reservoir, effective_history);
} else {
ro.reservoir = pack_reservoir(canonical);
}

if (DECOUPLED_SHADING) {
//FIXME: issue with near zero denominator. Do we need do use BASE_CANONICAL_MIS?
let contribution_weight = canonical.weight_sum / max(canonical.selected_target_score * mis_canonical, 0.1);
*color_and_weight += canonical.weight_sum * vec4<f32>(contribution_weight * canonical.radiance, 1.0);
if (canonical.selected_target_score > 0.0) {
let contribution_weight = canonical.weight_sum / canonical.selected_target_score;
*color_and_weight += canonical.weight_sum * vec4<f32>(contribution_weight * canonical.radiance, 1.0);
}
ro.color = (*color_and_weight).xyz / max((*color_and_weight).w, 0.001);
} else {
ro.color = ro.reservoir.contribution_weight * (*reservoir).radiance;
Expand All @@ -441,45 +475,29 @@ fn finalize_resampling(

fn resample_temporal(
surface: Surface, cur_pixel: vec2<i32>, position: vec3<f32>,
rng: ptr<function, RandomState>, enable_debug: bool
rng: ptr<function, RandomState>, debug_len: f32,
) -> ResampleOutput {
if (debug.view_mode == DebugMode_TemporalMatch || debug.view_mode == DebugMode_TemporalMisCanonical || debug.view_mode == DebugMode_TemporalMisError) {
textureStore(out_debug, cur_pixel, vec4<f32>(0.0));
}
if (surface.depth == 0.0) {
return ResampleOutput();
}
let debug_len = select(0.0, surface.depth * 0.2, enable_debug);

// build the canonical sample
var canonical = LiveReservoir();
for (var i = 0u; i < parameters.num_environment_samples; i += 1u) {
var ls: LightSample;
if (parameters.environment_importance_sampling != 0u) {
ls = sample_light_from_environment(rng);
} else {
ls = sample_light_from_sphere(rng);
}

let brdf = evaluate_sample(ls, surface, position, debug_len);
if (brdf > 0.0) {
let other = make_reservoir(ls, 0u, vec3<f32>(brdf));
merge_reservoir(&canonical, other, random_gen(rng));
} else {
bump_reservoir(&canonical, 1.0);
}
}
let canonical = produce_canonical(surface, position, rng, debug_len);

//TODO: find best match in a 2x2 grid
let prev_pixel = vec2<i32>(get_prev_pixel(cur_pixel, position));

let prev_reservoir_index = get_reservoir_index(prev_pixel, prev_camera);
if (parameters.temporal_tap == 0u || prev_reservoir_index < 0) {
return ResampleOutput(pack_reservoir(canonical), vec3<f32>(0.0));
return finalize_canonical(canonical);
}

let prev_reservoir = prev_reservoirs[prev_reservoir_index];
let prev_surface = read_prev_surface(prev_pixel);
// if the surfaces are too different, there is no trust in this sample
if (prev_reservoir.confidence == 0.0 || compare_surfaces(surface, prev_surface) < 0.1) {
return ResampleOutput(pack_reservoir(canonical), vec3<f32>(0.0));
return finalize_canonical(canonical);
}

var reservoir = LiveReservoir();
Expand All @@ -489,25 +507,36 @@ fn resample_temporal(
let prev_dir = get_ray_direction(prev_camera, prev_pixel);
let prev_world_pos = prev_camera.position + prev_surface.depth * prev_dir;
let other = PixelCache(prev_surface, prev_reservoir, prev_world_pos);
let rr = resample(&reservoir, &color_and_weight, base, other, prev_acc_struct, parameters.temporal_history, rng, enable_debug);
let total_samples = 2.0;
let mis_canonical = select(0.0, 1.0 / total_samples, DEFENSIVE_MIS) + rr.mis_canonical;
let rr = resample(&reservoir, &color_and_weight, base, other, prev_acc_struct, parameters.temporal_history, rng, debug_len);
let mis_canonical = 1.0 + rr.mis_canonical;

if (debug.view_mode == DebugMode_TemporalMatch) {
textureStore(out_debug, cur_pixel, vec4<f32>(1.0));
}
if (debug.view_mode == DebugMode_TemporalMisCanonical) {
textureStore(out_debug, cur_pixel, vec4<f32>(mis_canonical / (1.0 + base.accepted_count)));
}
if (debug.view_mode == DebugMode_TemporalMisError) {
let total = mis_canonical + rr.mis_sample;
textureStore(out_debug, cur_pixel, vec4<f32>(abs(total - 1.0 - base.accepted_count)));
}
return finalize_resampling(&reservoir, &color_and_weight, base, mis_canonical, rng);
}

fn resample_spatial(
surface: Surface, cur_pixel: vec2<i32>, position: vec3<f32>,
group_id: vec3<u32>, canonical_stored: StoredReservoir,
rng: ptr<function, RandomState>, enable_debug: bool
group_id: vec3<u32>, canonical: LiveReservoir,
rng: ptr<function, RandomState>, debug_len: f32,
) -> ResampleOutput {
if (surface.depth == 0.0) {
if (debug.view_mode == DebugMode_SpatialMatch || debug.view_mode == DebugMode_SpatialMisCanonical || debug.view_mode == DebugMode_SpatialMisError) {
textureStore(out_debug, cur_pixel, vec4<f32>(0.0));
}
let dir = normalize(position - camera.position);
var ro = ResampleOutput();
ro.color = evaluate_environment(dir);
return ro;
}
let debug_len = select(0.0, surface.depth * 0.2, enable_debug);

// gather the list of neighbors (within the workgroup) to resample.
var accepted_count = 0u;
Expand All @@ -528,20 +557,31 @@ fn resample_spatial(
}
}

let canonical = unpack_reservoir(canonical_stored, ~0u);
var reservoir = LiveReservoir();
let total_samples = 1.0 + f32(accepted_count);
var mis_canonical = select(0.0, 1.0 / total_samples, DEFENSIVE_MIS);
var color_and_weight = vec4<f32>(0.0);
let base = ResampleBase(surface, canonical, position, f32(accepted_count));
var mis_canonical = 1.0;
var mis_sample_sum = 0.0;

// evaluate the MIS of each of the samples versus the canonical one.
for (var lid = 0u; lid < accepted_count; lid += 1u) {
let other = pixel_cache[accepted_local_indices[lid]];
let rr = resample(&reservoir, &color_and_weight, base, other, acc_struct, parameters.spatial_tap_history, rng, enable_debug);
let rr = resample(&reservoir, &color_and_weight, base, other, acc_struct, parameters.spatial_tap_history, rng, debug_len);
mis_canonical += rr.mis_canonical;
mis_sample_sum += rr.mis_sample;
}

if (debug.view_mode == DebugMode_SpatialMatch) {
let value = f32(accepted_count) / max(1.0, f32(parameters.spatial_taps));
textureStore(out_debug, cur_pixel, vec4<f32>(value));
}
if (debug.view_mode == DebugMode_SpatialMisCanonical) {
textureStore(out_debug, cur_pixel, vec4<f32>(mis_canonical / (1.0 + base.accepted_count)));
}
if (debug.view_mode == DebugMode_SpatialMisError) {
let total = mis_canonical + mis_sample_sum;
textureStore(out_debug, cur_pixel, vec4<f32>(abs(total - 1.0 - base.accepted_count)));
}
return finalize_resampling(&reservoir, &color_and_weight, base, mis_canonical, rng);
}

Expand All @@ -553,22 +593,23 @@ fn compute_restir(
if (debug.view_mode == DebugMode_Depth) {
textureStore(out_debug, pixel, vec4<f32>(surface.depth / camera.depth));
}
let ray_dir = get_ray_direction(camera, pixel);
let pixel_index = get_reservoir_index(pixel, camera);

let position = camera.position + surface.depth * ray_dir;
if (debug.view_mode == DebugMode_Normal) {
let normal = qrot(surface.basis, vec3<f32>(0.0, 0.0, 1.0));
textureStore(out_debug, pixel, vec4<f32>(normal, 0.0));
}
let debug_len = select(0.0, surface.depth * 0.2, enable_debug);
let ray_dir = get_ray_direction(camera, pixel);
let pixel_index = get_reservoir_index(pixel, camera);
let position = camera.position + surface.depth * ray_dir;

let temporal = resample_temporal(surface, pixel, position, rng, enable_debug);
let temporal = resample_temporal(surface, pixel, position, rng, debug_len);
pixel_cache[local_index] = PixelCache(surface, temporal.reservoir, position);

// sync with the workgroup to ensure all reservoirs are available.
workgroupBarrier();

let spatial = resample_spatial(surface, pixel, position, group_id, temporal.reservoir, rng, enable_debug);
let temporal_live = revive_canonical(temporal);
let spatial = resample_spatial(surface, pixel, position, group_id, temporal_live, rng, debug_len);
reservoirs[pixel_index] = spatial.reservoir;
return spatial.color;
}
Expand Down Expand Up @@ -598,6 +639,7 @@ fn main(
let enable_debug = all(pixel_coord == vec2<i32>(debug.mouse_pos));
let enable_restir_debug = (debug.draw_flags & DebugDrawFlags_RESTIR) != 0u && enable_debug;
let color = compute_restir(pixel_coord, local_index, group_id, &rng, enable_restir_debug);

if (enable_debug) {
debug_buf.variance.color_sum += color;
debug_buf.variance.color2_sum += color * color;
Expand Down
Loading

0 comments on commit ddee75e

Please sign in to comment.