diff --git a/blade-render/code/atrous.wgsl b/blade-render/code/atrous.wgsl deleted file mode 100644 index 71195499..00000000 --- a/blade-render/code/atrous.wgsl +++ /dev/null @@ -1,53 +0,0 @@ -#include "surface.inc.wgsl" - -struct Params { - extent: vec2, -} - -var params: Params; -var t_flat_normal: texture_2d; -var t_depth: texture_2d; -var input: texture_2d; -var output: texture_storage_2d; - -fn read_surface(pixel: vec2) -> Surface { - var surface = Surface(); - surface.flat_normal = normalize(textureLoad(t_flat_normal, pixel, 0).xyz); - surface.depth = textureLoad(t_depth, pixel, 0).x; - return surface; -} - -const gaussian_weights = vec2(0.44198, 0.27901); - -@compute @workgroup_size(8, 8) -fn main(@builtin(global_invocation_id) global_id: vec3) { - let center = vec2(global_id.xy); - if (any(center >= params.extent)) { - return; - } - - let center_radiance = textureLoad(input, center, 0).xyz; - let center_suf = read_surface(center); - var sum_weight = gaussian_weights[0] * gaussian_weights[0]; - var sum_radiance = center_radiance * sum_weight; - - for (var yy=-1; yy<=1; yy+=1) { - for (var xx=-1; xx<=1; xx+=1) { - let p = center + vec2(xx, yy); - if (all(p == center) || any(p < vec2(0)) || any(p >= params.extent)) { - continue; - } - - //TODO: store in group-shared memory - let surface = read_surface(p); - var weight = gaussian_weights[abs(xx)] * gaussian_weights[abs(yy)]; - weight *= compare_surfaces(center_suf, surface); - let radiance = textureLoad(input, p, 0).xyz; - sum_radiance += weight * radiance; - sum_weight += weight; - } - } - - let radiance = sum_radiance / sum_weight; - textureStore(output, global_id.xy, vec4(radiance, 0.0)); -} diff --git a/blade-render/code/blur.wgsl b/blade-render/code/blur.wgsl new file mode 100644 index 00000000..903c8737 --- /dev/null +++ b/blade-render/code/blur.wgsl @@ -0,0 +1,94 @@ +#include "camera.inc.wgsl" +#include "quaternion.inc.wgsl" +#include "surface.inc.wgsl" + +// Spatio-temporal variance-guided filtering +// https://research.nvidia.com/sites/default/files/pubs/2017-07_Spatiotemporal-Variance-Guided-Filtering%3A//svgf_preprint.pdf + +struct Params { + extent: vec2, + temporal_weight: f32, +} + +var camera: CameraParams; +var prev_camera: CameraParams; +var params: Params; +var t_flat_normal: texture_2d; +var t_prev_flat_normal: texture_2d; +var t_depth: texture_2d; +var t_prev_depth: texture_2d; +var input: texture_2d; +var prev_input: texture_2d; +var output: texture_storage_2d; + +fn read_surface(pixel: vec2) -> Surface { + var surface = Surface(); + surface.flat_normal = normalize(textureLoad(t_flat_normal, pixel, 0).xyz); + surface.depth = textureLoad(t_depth, pixel, 0).x; + return surface; +} +fn read_prev_surface(pixel: vec2) -> Surface { + var surface = Surface(); + surface.flat_normal = normalize(textureLoad(t_prev_flat_normal, pixel, 0).xyz); + surface.depth = textureLoad(t_prev_depth, pixel, 0).x; + return surface; +} + +@compute @workgroup_size(8, 8) +fn temporal_accum(@builtin(global_invocation_id) global_id: vec3) { + let pixel = vec2(global_id.xy); + if (any(pixel >= params.extent)) { + return; + } + //TODO: use motion vectors + let cur_radiance = textureLoad(input, pixel, 0).xyz; + let surface = read_surface(pixel); + let pos_world = camera.position + surface.depth * get_ray_direction(camera, pixel); + let prev_pixel = get_projected_pixel(prev_camera, pos_world); + var prev_radiance = cur_radiance; + var history_weight = 1.0 - params.temporal_weight; + if (all(prev_pixel >= vec2(0)) && all(prev_pixel < params.extent)) { + prev_radiance = textureLoad(prev_input, prev_pixel, 0).xyz; + let prev_surface = read_prev_surface(prev_pixel); + let projected_distance = length(pos_world - prev_camera.position); + history_weight *= compare_flat_normals(surface.flat_normal, prev_surface.flat_normal); + history_weight *= compare_depths(surface.depth, projected_distance); + } + let radiance = mix(cur_radiance, prev_radiance, history_weight); + textureStore(output, global_id.xy, vec4(radiance, 0.0)); +} + +const gaussian_weights = vec2(0.44198, 0.27901); + +@compute @workgroup_size(8, 8) +fn atrous(@builtin(global_invocation_id) global_id: vec3) { + let center = vec2(global_id.xy); + if (any(center >= params.extent)) { + return; + } + + let center_radiance = textureLoad(input, center, 0).xyz; + let center_suf = read_surface(center); + var sum_weight = gaussian_weights[0] * gaussian_weights[0]; + var sum_radiance = center_radiance * sum_weight; + + for (var yy=-1; yy<=1; yy+=1) { + for (var xx=-1; xx<=1; xx+=1) { + let p = center + vec2(xx, yy); + if (all(p == center) || any(p < vec2(0)) || any(p >= params.extent)) { + continue; + } + + //TODO: store in group-shared memory + let surface = read_surface(p); + var weight = gaussian_weights[abs(xx)] * gaussian_weights[abs(yy)]; + //weight *= compare_surfaces(center_suf, surface); + let radiance = textureLoad(input, p, 0).xyz; + sum_radiance += weight * radiance; + sum_weight += weight; + } + } + + let radiance = sum_radiance / sum_weight; + textureStore(output, global_id.xy, vec4(radiance, 0.0)); +} diff --git a/blade-render/code/surface.inc.wgsl b/blade-render/code/surface.inc.wgsl index 892f97e5..d6c70fcc 100644 --- a/blade-render/code/surface.inc.wgsl +++ b/blade-render/code/surface.inc.wgsl @@ -4,11 +4,19 @@ struct Surface { depth: f32, } +fn compare_flat_normals(a: vec3, b: vec3) -> f32 { + return smoothstep(0.4, 0.9, dot(a, b)); +} + +fn compare_depths(a: f32, b: f32) -> f32 { + return 1.0 - smoothstep(0.0, 100.0, abs(a - b)); +} + // Return the compatibility rating, where // 1.0 means fully compatible, and // 0.0 means totally incompatible. fn compare_surfaces(a: Surface, b: Surface) -> f32 { - let r_normal = smoothstep(0.4, 0.9, dot(a.flat_normal, b.flat_normal)); - let r_depth = 1.0 - smoothstep(0.0, 100.0, abs(a.depth - b.depth)); + let r_normal = compare_flat_normals(a.flat_normal, b.flat_normal); + let r_depth = compare_depths(a.depth, b.depth); return r_normal * r_depth; } diff --git a/blade-render/src/render/mod.rs b/blade-render/src/render/mod.rs index c50f411d..1a99f59d 100644 --- a/blade-render/src/render/mod.rs +++ b/blade-render/src/render/mod.rs @@ -90,6 +90,7 @@ pub struct RayConfig { #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] pub struct DenoiserConfig { pub num_passes: u32, + pub temporal_weight: f32, } pub struct SelectionInfo { @@ -141,12 +142,13 @@ struct DebugRender { buffer_size: u32, } +#[allow(dead_code)] struct DoubleRenderTarget { texture: blade_graphics::Texture, views: [blade_graphics::TextureView; 2], active: usize, } - +#[allow(dead_code)] impl DoubleRenderTarget { fn new( name: &str, @@ -216,6 +218,8 @@ struct FrameData { flat_normal_view: blade_graphics::TextureView, albedo: blade_graphics::Texture, albedo_view: blade_graphics::TextureView, + light_diffuse: blade_graphics::Texture, + light_diffuse_view: blade_graphics::TextureView, camera_params: CameraParams, } @@ -271,7 +275,7 @@ impl FrameData { encoder.init_texture(basis); let (flat_normal, flat_normal_view) = Self::create_target( - "flat_normal", + "flat-normal", blade_graphics::TextureFormat::Rgba8Snorm, size, gpu, @@ -286,6 +290,10 @@ impl FrameData { ); encoder.init_texture(albedo); + let (light_diffuse, light_diffuse_view) = + Self::create_target("light-diffuse", RADIANCE_FORMAT, size, gpu); + encoder.init_texture(light_diffuse); + Self { reservoir_buf, depth, @@ -296,6 +304,8 @@ impl FrameData { flat_normal_view, albedo, albedo_view, + light_diffuse, + light_diffuse_view, camera_params: CameraParams::default(), } } @@ -310,9 +320,16 @@ impl FrameData { gpu.destroy_texture(self.flat_normal); gpu.destroy_texture_view(self.albedo_view); gpu.destroy_texture(self.albedo); + gpu.destroy_texture_view(self.light_diffuse_view); + gpu.destroy_texture(self.light_diffuse); } } +struct Blur { + temporal_accum_pipeline: blade_graphics::ComputePipeline, + atrous_pipeline: blade_graphics::ComputePipeline, +} + /// Blade Renderer is a comprehensive rendering solution for /// end user applications. /// @@ -326,13 +343,15 @@ pub struct Renderer { config: RenderConfig, shaders: Shaders, frame_data: [FrameData; 2], - lighting_diffuse: DoubleRenderTarget, + light_temp_texture: blade_graphics::Texture, + light_temp_view: blade_graphics::TextureView, + post_proc_input: blade_graphics::TextureView, debug_texture: blade_graphics::Texture, debug_view: blade_graphics::TextureView, fill_pipeline: blade_graphics::ComputePipeline, main_pipeline: blade_graphics::ComputePipeline, - atrous_pipeline: blade_graphics::ComputePipeline, - blit_pipeline: blade_graphics::RenderPipeline, + post_proc_pipeline: blade_graphics::RenderPipeline, + blur: Blur, scene: super::Scene, acceleration_structure: blade_graphics::AccelerationStructure, env_map: EnvironmentMap, @@ -429,13 +448,29 @@ struct MainData { #[repr(C)] #[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)] -struct AtrousParams { +struct BlurParams { extent: [u32; 2], + temporal_weight: f32, + pad: f32, +} + +#[derive(blade_macros::ShaderData)] +struct TemporalAccumData { + camera: CameraParams, + prev_camera: CameraParams, + params: BlurParams, + input: blade_graphics::TextureView, + prev_input: blade_graphics::TextureView, + t_depth: blade_graphics::TextureView, + t_prev_depth: blade_graphics::TextureView, + t_flat_normal: blade_graphics::TextureView, + t_prev_flat_normal: blade_graphics::TextureView, + output: blade_graphics::TextureView, } #[derive(blade_macros::ShaderData)] struct AtrousData { - params: AtrousParams, + params: BlurParams, input: blade_graphics::TextureView, t_flat_normal: blade_graphics::TextureView, t_depth: blade_graphics::TextureView, @@ -503,7 +538,7 @@ struct HitEntry { pub struct Shaders { fill_gbuf: blade_asset::Handle, ray_trace: blade_asset::Handle, - atrous: blade_asset::Handle, + blur: blade_asset::Handle, post_proc: blade_asset::Handle, debug_draw: blade_asset::Handle, debug_blit: blade_asset::Handle, @@ -515,7 +550,7 @@ impl Shaders { let shaders = Self { fill_gbuf: ctx.load_shader("fill-gbuf.wgsl"), ray_trace: ctx.load_shader("ray-trace.wgsl"), - atrous: ctx.load_shader("atrous.wgsl"), + blur: ctx.load_shader("blur.wgsl"), post_proc: ctx.load_shader("post-proc.wgsl"), debug_draw: ctx.load_shader("debug-draw.wgsl"), debug_blit: ctx.load_shader("debug-blit.wgsl"), @@ -527,6 +562,7 @@ impl Shaders { struct ShaderPipelines { fill: blade_graphics::ComputePipeline, main: blade_graphics::ComputePipeline, + temporal_accum: blade_graphics::ComputePipeline, atrous: blade_graphics::ComputePipeline, post_proc: blade_graphics::RenderPipeline, debug_draw: blade_graphics::RenderPipeline, @@ -567,6 +603,19 @@ impl ShaderPipelines { compute: shader.at("main"), }) } + + fn create_temporal_accum( + shader: &blade_graphics::Shader, + gpu: &blade_graphics::Context, + ) -> blade_graphics::ComputePipeline { + let layout = ::layout(); + gpu.create_compute_pipeline(blade_graphics::ComputePipelineDesc { + name: "temporal-accum", + data_layouts: &[&layout], + compute: shader.at("temporal_accum"), + }) + } + fn create_atrous( shader: &blade_graphics::Shader, gpu: &blade_graphics::Context, @@ -575,9 +624,10 @@ impl ShaderPipelines { gpu.create_compute_pipeline(blade_graphics::ComputePipelineDesc { name: "atrous", data_layouts: &[&layout], - compute: shader.at("main"), + compute: shader.at("atrous"), }) } + fn create_post_proc( shader: &blade_graphics::Shader, format: blade_graphics::TextureFormat, @@ -597,6 +647,7 @@ impl ShaderPipelines { depth_stencil: None, }) } + fn create_debug_draw( shader: &blade_graphics::Shader, format: blade_graphics::TextureFormat, @@ -620,6 +671,7 @@ impl ShaderPipelines { }], }) } + fn create_debug_blit( shader: &blade_graphics::Shader, format: blade_graphics::TextureFormat, @@ -648,10 +700,12 @@ impl ShaderPipelines { shader_man: &blade_asset::AssetManager, ) -> Result { let sh_main = shader_man[shaders.ray_trace].raw.as_ref().unwrap(); + let sh_blur = shader_man[shaders.blur].raw.as_ref().unwrap(); Ok(Self { fill: Self::create_gbuf_fill(shader_man[shaders.fill_gbuf].raw.as_ref().unwrap(), gpu), main: Self::create_ray_trace(sh_main, gpu), - atrous: Self::create_atrous(shader_man[shaders.atrous].raw.as_ref().unwrap(), gpu), + temporal_accum: Self::create_temporal_accum(sh_blur, gpu), + atrous: Self::create_atrous(sh_blur, gpu), post_proc: Self::create_post_proc( shader_man[shaders.post_proc].raw.as_ref().unwrap(), config.surface_format, @@ -731,6 +785,10 @@ impl Renderer { FrameData::new(config.screen_size, sp.reservoir_size, encoder, gpu), FrameData::new(config.screen_size, sp.reservoir_size, encoder, gpu), ]; + let (light_temp_texture, light_temp_view) = + FrameData::create_target("light-temp", RADIANCE_FORMAT, config.screen_size, gpu); + encoder.init_texture(light_temp_texture); + let dummy = DummyResources::new(encoder, gpu); let (debug_texture, debug_view) = FrameData::create_target( "debug", @@ -763,20 +821,19 @@ impl Renderer { config: *config, shaders, frame_data, - lighting_diffuse: DoubleRenderTarget::new( - "light/diffuse", - RADIANCE_FORMAT, - config.screen_size, - encoder, - gpu, - ), + light_temp_texture, + light_temp_view, + post_proc_input: blade_graphics::TextureView::default(), debug_texture, debug_view, scene: super::Scene::default(), fill_pipeline: sp.fill, main_pipeline: sp.main, - atrous_pipeline: sp.atrous, - blit_pipeline: sp.post_proc, + post_proc_pipeline: sp.post_proc, + blur: Blur { + temporal_accum_pipeline: sp.temporal_accum, + atrous_pipeline: sp.atrous, + }, acceleration_structure: blade_graphics::AccelerationStructure::default(), env_map: EnvironmentMap::with_pipeline(&dummy, sp.env_preproc), dummy, @@ -800,7 +857,8 @@ impl Renderer { for frame_data in self.frame_data.iter_mut() { frame_data.destroy(gpu); } - self.lighting_diffuse.destroy(gpu); + gpu.destroy_texture(self.light_temp_texture); + gpu.destroy_texture_view(self.light_temp_view); gpu.destroy_texture(self.debug_texture); gpu.destroy_texture_view(self.debug_view); if self.hit_buffer != blade_graphics::Buffer::default() { @@ -835,6 +893,7 @@ impl Renderer { tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.fill_gbuf)); tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.ray_trace)); + tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.blur)); tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.post_proc)); tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.debug_draw)); tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.debug_blit)); @@ -868,14 +927,16 @@ impl Renderer { self.main_pipeline = ShaderPipelines::create_ray_trace(shader, gpu); } } - if self.shaders.atrous != old.atrous { - if let Ok(ref shader) = asset_hub.shaders[self.shaders.atrous].raw { - self.atrous_pipeline = ShaderPipelines::create_atrous(shader, gpu); + if self.shaders.blur != old.blur { + if let Ok(ref shader) = asset_hub.shaders[self.shaders.blur].raw { + self.blur.temporal_accum_pipeline = + ShaderPipelines::create_temporal_accum(shader, gpu); + self.blur.atrous_pipeline = ShaderPipelines::create_atrous(shader, gpu); } } if self.shaders.post_proc != old.post_proc { if let Ok(ref shader) = asset_hub.shaders[self.shaders.post_proc].raw { - self.blit_pipeline = + self.post_proc_pipeline = ShaderPipelines::create_post_proc(shader, self.config.surface_format, gpu); } } @@ -921,9 +982,13 @@ impl Renderer { *frame_data = FrameData::new(size, self.reservoir_size, encoder, gpu); } - self.lighting_diffuse.destroy(gpu); - self.lighting_diffuse = - DoubleRenderTarget::new("light/diffuse", RADIANCE_FORMAT, size, encoder, gpu); + gpu.destroy_texture(self.light_temp_texture); + gpu.destroy_texture_view(self.light_temp_view); + let (light_temp_texture, light_temp_view) = + FrameData::create_target("light-temp", RADIANCE_FORMAT, size, gpu); + encoder.init_texture(light_temp_texture); + self.light_temp_texture = light_temp_texture; + self.light_temp_view = light_temp_view; gpu.destroy_texture(self.debug_texture); gpu.destroy_texture_view(self.debug_view); @@ -1166,6 +1231,7 @@ impl Renderer { self.frame_index += 1; self.frame_data.swap(0, 1); self.frame_data[0].camera_params = self.make_camera_params(camera); + self.post_proc_input = self.frame_data[0].light_diffuse_view; } fn make_camera_params(&self, camera: &super::Camera) -> CameraParams { @@ -1263,7 +1329,7 @@ impl Renderer { debug_buf: self.debug.buffer.into(), reservoirs: cur.reservoir_buf.into(), prev_reservoirs: prev.reservoir_buf.into(), - out_diffuse: self.lighting_diffuse.cur(), + out_diffuse: cur.light_diffuse_view, out_debug: self.debug_view, }, ); @@ -1272,30 +1338,70 @@ impl Renderer { } pub fn denoise( - &mut self, + &mut self, //TODO: borrow immutably command_encoder: &mut blade_graphics::CommandEncoder, denoiser_config: DenoiserConfig, ) { - let cur = self.frame_data.first().unwrap(); - for _ in 0..denoiser_config.num_passes { + let params = BlurParams { + extent: [self.screen_size.width, self.screen_size.height], + temporal_weight: denoiser_config.temporal_weight, + pad: 0.0, + }; + if denoiser_config.temporal_weight < 1.0 { + let cur = self.frame_data.first().unwrap(); + let prev = self.frame_data.last().unwrap(); if let mut pass = command_encoder.compute() { - self.lighting_diffuse.swap(); - let mut pc = pass.with(&self.atrous_pipeline); - let groups = self.atrous_pipeline.get_dispatch_for(self.screen_size); + let mut pc = pass.with(&self.blur.temporal_accum_pipeline); + let groups = self.blur.atrous_pipeline.get_dispatch_for(self.screen_size); pc.bind( 0, - &AtrousData { - params: AtrousParams { - extent: [self.screen_size.width, self.screen_size.height], - }, - input: self.lighting_diffuse.prev(), - t_flat_normal: cur.flat_normal_view, + &TemporalAccumData { + camera: cur.camera_params, + prev_camera: prev.camera_params, + params, + input: cur.light_diffuse_view, + prev_input: prev.light_diffuse_view, t_depth: cur.depth_view, - output: self.lighting_diffuse.cur(), + t_prev_depth: prev.depth_view, + t_flat_normal: cur.flat_normal_view, + t_prev_flat_normal: prev.flat_normal_view, + output: self.light_temp_view, }, ); pc.dispatch(groups); } + + // make it so `cur.light_diffuse_view` always contains the fresh reprojection result + let cur_mut = self.frame_data.first_mut().unwrap(); + mem::swap(&mut self.light_temp_view, &mut cur_mut.light_diffuse_view); + mem::swap(&mut self.light_temp_texture, &mut cur_mut.light_diffuse); + } + + { + let cur = self.frame_data.first().unwrap(); + let prev = self.frame_data.last().unwrap(); + self.post_proc_input = cur.light_diffuse_view; + //Note: we no longer need `prev.light_diffuse_view` so reusing it here + let mut targets = [self.light_temp_view, prev.light_diffuse_view]; + for _ in 0..denoiser_config.num_passes { + if let mut pass = command_encoder.compute() { + let mut pc = pass.with(&self.blur.atrous_pipeline); + let groups = self.blur.atrous_pipeline.get_dispatch_for(self.screen_size); + pc.bind( + 0, + &AtrousData { + params, + input: self.post_proc_input, + t_flat_normal: cur.flat_normal_view, + t_depth: cur.depth_view, + output: targets[0], + }, + ); + pc.dispatch(groups); + self.post_proc_input = targets[0]; + targets.swap(0, 1); // rotate the views + } + } } } @@ -1308,12 +1414,12 @@ impl Renderer { ) { let pp = &self.scene.post_processing; let cur = self.frame_data.first().unwrap(); - if let mut pc = pass.with(&self.blit_pipeline) { + if let mut pc = pass.with(&self.post_proc_pipeline) { pc.bind( 0, &PostProcData { t_albedo: cur.albedo_view, - light_diffuse: self.lighting_diffuse.cur(), + light_diffuse: self.post_proc_input, t_debug: self.debug_view, tone_map_params: ToneMapParams { mode: mode as u32, diff --git a/examples/scene/main.rs b/examples/scene/main.rs index 546d0206..ccd0d4d6 100644 --- a/examples/scene/main.rs +++ b/examples/scene/main.rs @@ -84,6 +84,7 @@ struct Example { last_render_time: time::Instant, render_times: VecDeque, ray_config: blade_render::RayConfig, + denoiser_enabled: bool, denoiser_config: blade_render::DenoiserConfig, debug_blit: Option, debug_blit_input: DebugBlitInput, @@ -231,7 +232,11 @@ impl Example { spatial_tap_history: 5, spatial_radius: 10, }, - denoiser_config: blade_render::DenoiserConfig { num_passes: 5 }, + denoiser_enabled: true, + denoiser_config: blade_render::DenoiserConfig { + num_passes: 5, + temporal_weight: 0.1, + }, debug_blit: None, debug_blit_input: DebugBlitInput::None, workers, @@ -323,7 +328,9 @@ impl Example { self.need_accumulation_reset = false; self.renderer .ray_trace(command_encoder, self.debug, self.ray_config); - self.renderer.denoise(command_encoder, self.denoiser_config); + if self.denoiser_enabled { + self.renderer.denoise(command_encoder, self.denoiser_config); + } } let frame = self.context.acquire_frame(); @@ -600,7 +607,12 @@ impl Example { egui::CollapsingHeader::new("Denoise") .default_open(true) .show(ui, |ui| { + ui.checkbox(&mut self.denoiser_enabled, "Enable"); let dc = &mut self.denoiser_config; + ui.add( + egui::Slider::new(&mut dc.temporal_weight, 0.0..=1.0f32) + .text("Temporal weight"), + ); ui.add(egui::Slider::new(&mut dc.num_passes, 0..=15u32).text("A-trous passes")); });