diff --git a/wgpu-core/src/command/ray_tracing.rs b/wgpu-core/src/command/ray_tracing.rs index 3fbe48b78a..e94458f9db 100644 --- a/wgpu-core/src/command/ray_tracing.rs +++ b/wgpu-core/src/command/ray_tracing.rs @@ -14,18 +14,28 @@ use crate::{ FastHashSet, }; -use wgt::{math::align_to, BufferUsages}; +use wgt::{math::align_to, BufferUsages, BufferAddress}; +use super::{BakedCommands, CommandBufferMutable, CommandEncoderError}; use crate::lock::rank; use crate::ray_tracing::BlasTriangleGeometry; use crate::resource::{Buffer, Labeled, StagingBuffer, Trackable}; +use crate::snatch::SnatchGuard; +use crate::storage::Storage; use crate::track::PendingTransition; -use hal::{BufferUses, CommandEncoder, Device}; +use hal::{Api, BufferUses, CommandEncoder, Device}; use std::ops::Deref; use std::sync::Arc; use std::{cmp::max, iter, num::NonZeroU64, ops::Range, ptr}; -use super::{BakedCommands, CommandEncoderError}; +type BufferStorage<'a, A> = Vec<( + Arc>, + Option>, + Option<(Arc>, Option>)>, + Option<(Arc>, Option>)>, + BlasTriangleGeometry<'a>, + Option>>, +)>; // This should be queried from the device, maybe the the hal api should pre aline it, since I am unsure how else we can idiomatically get this value. const SCRATCH_BUFFER_ALIGNMENT: u32 = 256; @@ -146,338 +156,30 @@ impl Global { )>::new(); let mut scratch_buffer_blas_size = 0; - let mut blas_storage = Vec::<(&Blas, hal::AccelerationStructureEntries, u64)>::new(); + let mut blas_storage = + Vec::<(Arc>, hal::AccelerationStructureEntries, u64)>::new(); let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); - for entry in blas_iter { - let blas = cmd_buf_data.trackers.blas_s.insert_single( - blas_guard - .get(entry.blas_id) - .map_err(|_| BuildAccelerationStructureError::InvalidBlasId)? - .clone(), - ); - - if blas.raw.is_none() { - return Err(BuildAccelerationStructureError::InvalidBlas( - blas.error_ident(), - )); - } - cmd_buf_data.blas_actions.push(BlasAction { - blas: blas.clone(), - kind: crate::ray_tracing::BlasActionKind::Build(build_command_index), - }); - - match entry.geometries { - BlasGeometries::TriangleGeometries(triangle_geometries) => { - for (i, mesh) in triangle_geometries.enumerate() { - let size_desc = match &blas.sizes { - wgt::BlasGeometrySizeDescriptors::Triangles { desc } => desc, - }; - if i >= size_desc.len() { - return Err( - BuildAccelerationStructureError::IncompatibleBlasBuildSizes( - blas.error_ident(), - ), - ); - } - let size_desc = &size_desc[i]; - - if size_desc.flags != mesh.size.flags - || size_desc.vertex_count < mesh.size.vertex_count - || size_desc.vertex_format != mesh.size.vertex_format - || size_desc.index_count.is_none() != mesh.size.index_count.is_none() - || (size_desc.index_count.is_none() - || size_desc.index_count.unwrap() < mesh.size.index_count.unwrap()) - || size_desc.index_format.is_none() != mesh.size.index_format.is_none() - || (size_desc.index_format.is_none() - || size_desc.index_format.unwrap() - != mesh.size.index_format.unwrap()) - { - return Err( - BuildAccelerationStructureError::IncompatibleBlasBuildSizes( - blas.error_ident(), - ), - ); - } - - if size_desc.index_count.is_some() && mesh.index_buffer.is_none() { - return Err(BuildAccelerationStructureError::MissingIndexBuffer( - blas.error_ident(), - )); - } - let vertex_buffer = match buffer_guard.get(mesh.vertex_buffer) { - Ok(buffer) => buffer, - Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId), - }; - let vertex_pending = cmd_buf_data.trackers.buffers.set_single( - vertex_buffer, - BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, - ); - let index_data = if let Some(index_id) = mesh.index_buffer { - let index_buffer = match buffer_guard.get(index_id) { - Ok(buffer) => buffer, - Err(_) => { - return Err(BuildAccelerationStructureError::InvalidBufferId) - } - }; - if mesh.index_buffer_offset.is_none() - || mesh.size.index_count.is_none() - || mesh.size.index_count.is_none() - { - return Err( - BuildAccelerationStructureError::MissingAssociatedData( - index_buffer.error_ident(), - ), - ); - } - let data = cmd_buf_data.trackers.buffers.set_single( - index_buffer, - hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, - ); - Some((index_buffer.clone(), data)) - } else { - None - }; - let transform_data = if let Some(transform_id) = mesh.transform_buffer { - let transform_buffer = match buffer_guard.get(transform_id) { - Ok(buffer) => buffer, - Err(_) => { - return Err(BuildAccelerationStructureError::InvalidBufferId) - } - }; - if mesh.transform_buffer_offset.is_none() { - return Err( - BuildAccelerationStructureError::MissingAssociatedData( - transform_buffer.error_ident(), - ), - ); - } - let data = cmd_buf_data.trackers.buffers.set_single( - match buffer_guard.get(transform_id) { - Ok(buffer) => buffer, - Err(_) => { - return Err( - BuildAccelerationStructureError::InvalidBufferId, - ) - } - }, - BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, - ); - Some((transform_buffer.clone(), data)) - } else { - None - }; - - buf_storage.push(( - vertex_buffer.clone(), - vertex_pending, - index_data, - transform_data, - mesh, - None, - )) - } - if let Some(last) = buf_storage.last_mut() { - last.5 = Some(blas.clone()); - } - } - } - } + iter_blas( + blas_iter, + cmd_buf_data, + build_command_index, + &buffer_guard, + &blas_guard, + &mut buf_storage, + )?; - let mut triangle_entries = Vec::>::new(); let snatch_guard = device.snatchable_lock.read(); - for buf in &mut buf_storage { - let mesh = &buf.4; - let vertex_buffer = { - let vertex_buffer = buf.0.as_ref(); - let vertex_raw = vertex_buffer - .raw - .get(&snatch_guard) - .ok_or(BuildAccelerationStructureError::InvalidBufferId)?; - if !vertex_buffer.usage.contains(BufferUsages::BLAS_INPUT) { - return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( - vertex_buffer.error_ident(), - )); - } - if let Some(barrier) = buf - .1 - .take() - .map(|pending| pending.into_hal(vertex_buffer, &snatch_guard)) - { - input_barriers.push(barrier); - } - if vertex_buffer.size - < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride - { - return Err(BuildAccelerationStructureError::InsufficientBufferSize( - vertex_buffer.error_ident(), - vertex_buffer.size, - (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride, - )); - } - let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride; - cmd_buf_data.buffer_memory_init_actions.extend( - vertex_buffer.initialization_status.read().create_action( - buffer_guard.get(mesh.vertex_buffer).unwrap(), - vertex_buffer_offset - ..(vertex_buffer_offset - + mesh.size.vertex_count as u64 * mesh.vertex_stride), - MemoryInitKind::NeedsInitializedMemory, - ), - ); - vertex_raw - }; - let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) = buf.2 { - let index_id = mesh.index_buffer.as_ref().unwrap(); - let index_raw = index_buffer - .raw - .get(&snatch_guard) - .ok_or(BuildAccelerationStructureError::InvalidBufferId)?; - if !index_buffer.usage.contains(BufferUsages::BLAS_INPUT) { - return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( - index_buffer.error_ident(), - )); - } - if let Some(barrier) = index_pending - .take() - .map(|pending| pending.into_hal(index_buffer, &snatch_guard)) - { - input_barriers.push(barrier); - } - let index_stride = match mesh.size.index_format.unwrap() { - wgt::IndexFormat::Uint16 => 2, - wgt::IndexFormat::Uint32 => 4, - }; - if mesh.index_buffer_offset.unwrap() % index_stride != 0 { - return Err(BuildAccelerationStructureError::UnalignedIndexBufferOffset( - index_buffer.error_ident(), - )); - } - let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride; - - if mesh.size.index_count.unwrap() % 3 != 0 { - return Err(BuildAccelerationStructureError::InvalidIndexCount( - index_buffer.error_ident(), - mesh.size.index_count.unwrap(), - )); - } - if index_buffer.size - < mesh.size.index_count.unwrap() as u64 * index_stride - + mesh.index_buffer_offset.unwrap() - { - return Err(BuildAccelerationStructureError::InsufficientBufferSize( - index_buffer.error_ident(), - index_buffer.size, - mesh.size.index_count.unwrap() as u64 * index_stride - + mesh.index_buffer_offset.unwrap(), - )); - } - - cmd_buf_data.buffer_memory_init_actions.extend( - index_buffer.initialization_status.read().create_action( - match buffer_guard.get(*index_id) { - Ok(buffer) => buffer, - Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId), - }, - mesh.index_buffer_offset.unwrap() - ..(mesh.index_buffer_offset.unwrap() + index_buffer_size), - MemoryInitKind::NeedsInitializedMemory, - ), - ); - Some(index_raw) - } else { - None - }; - let transform_buffer = - if let Some((ref mut transform_buffer, ref mut transform_pending)) = buf.3 { - if mesh.transform_buffer_offset.is_none() { - return Err(BuildAccelerationStructureError::MissingAssociatedData( - transform_buffer.error_ident(), - )); - } - let transform_raw = transform_buffer.raw.get(&snatch_guard).ok_or( - BuildAccelerationStructureError::InvalidBuffer( - transform_buffer.error_ident(), - ), - )?; - if !transform_buffer.usage.contains(BufferUsages::BLAS_INPUT) { - return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( - transform_buffer.error_ident(), - )); - } - if let Some(barrier) = transform_pending - .take() - .map(|pending| pending.into_hal(transform_buffer, &snatch_guard)) - { - input_barriers.push(barrier); - } - if mesh.transform_buffer_offset.unwrap() % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 - { - return Err( - BuildAccelerationStructureError::UnalignedTransformBufferOffset( - transform_buffer.error_ident(), - ), - ); - } - if transform_buffer.size < 48 + mesh.transform_buffer_offset.unwrap() { - return Err(BuildAccelerationStructureError::InsufficientBufferSize( - transform_buffer.error_ident(), - transform_buffer.size, - 48 + mesh.transform_buffer_offset.unwrap(), - )); - } - cmd_buf_data.buffer_memory_init_actions.extend( - transform_buffer.initialization_status.read().create_action( - transform_buffer, - mesh.transform_buffer_offset.unwrap() - ..(mesh.index_buffer_offset.unwrap() + 48), - MemoryInitKind::NeedsInitializedMemory, - ), - ); - Some(transform_raw) - } else { - None - }; - - let triangles = hal::AccelerationStructureTriangles { - vertex_buffer: Some(vertex_buffer), - vertex_format: mesh.size.vertex_format, - first_vertex: mesh.first_vertex, - vertex_count: mesh.size.vertex_count, - vertex_stride: mesh.vertex_stride, - indices: index_buffer.map(|index_buffer| { - hal::AccelerationStructureTriangleIndices:: { - format: mesh.size.index_format.unwrap(), - buffer: Some(index_buffer), - offset: mesh.index_buffer_offset.unwrap() as u32, - count: mesh.size.index_count.unwrap(), - } - }), - transform: transform_buffer.map(|transform_buffer| { - hal::AccelerationStructureTriangleTransform { - buffer: transform_buffer, - offset: mesh.transform_buffer_offset.unwrap() as u32, - } - }), - flags: mesh.size.flags, - }; - triangle_entries.push(triangles); - if let Some(blas) = buf.5.as_ref() { - let scratch_buffer_offset = scratch_buffer_blas_size; - scratch_buffer_blas_size += align_to( - blas.size_info.build_scratch_size as u32, - SCRATCH_BUFFER_ALIGNMENT, - ) as u64; - - blas_storage.push(( - blas, - hal::AccelerationStructureEntries::Triangles(triangle_entries), - scratch_buffer_offset, - )); - triangle_entries = Vec::new(); - } - } + iter_buffers( + &mut buf_storage, + &snatch_guard, + &mut input_barriers, + cmd_buf_data, + &buffer_guard, + &mut scratch_buffer_blas_size, + &mut blas_storage, + )?; let mut scratch_buffer_tlas_size = 0; let mut tlas_storage = Vec::<(&Tlas, hal::AccelerationStructureEntries, u64)>::new(); @@ -568,35 +270,21 @@ impl Global { .create_buffer(&hal::BufferDescriptor { label: Some("(wgpu) scratch buffer"), size: max(scratch_buffer_blas_size, scratch_buffer_tlas_size), - usage: hal::BufferUses::ACCELERATION_STRUCTURE_SCRATCH, + usage: hal::BufferUses::ACCELERATION_STRUCTURE_SCRATCH | BufferUses::MAP_WRITE, memory_flags: hal::MemoryFlags::empty(), }) - .unwrap() + .map_err(crate::device::DeviceError::from)? }; let scratch_buffer_barrier = hal::BufferBarrier:: { buffer: &scratch_buffer, - usage: hal::BufferUses::ACCELERATION_STRUCTURE_SCRATCH - ..hal::BufferUses::ACCELERATION_STRUCTURE_SCRATCH, + usage: BufferUses::ACCELERATION_STRUCTURE_SCRATCH + ..BufferUses::ACCELERATION_STRUCTURE_SCRATCH, }; - let blas_descriptors = - blas_storage - .iter() - .map(|&(blas, ref entries, ref scratch_buffer_offset)| { - if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate { - log::info!("only rebuild implemented") - } - hal::BuildAccelerationStructureDescriptor { - entries, - mode: hal::AccelerationStructureBuildMode::Build, - flags: blas.flags, - source_acceleration_structure: None, - destination_acceleration_structure: blas.raw.as_ref().unwrap(), - scratch_buffer: &scratch_buffer, - scratch_buffer_offset: *scratch_buffer_offset, - } - }); + let blas_descriptors = blas_storage + .iter() + .map(|storage| map_blas(storage, &scratch_buffer)); let tlas_descriptors = tlas_storage @@ -620,41 +308,19 @@ impl Global { let tlas_present = !tlas_storage.is_empty(); let cmd_buf_raw = cmd_buf_data.encoder.open()?; - unsafe { - cmd_buf_raw.transition_buffers(input_barriers.into_iter()); - - if blas_present { - cmd_buf_raw.place_acceleration_structure_barrier( - hal::AccelerationStructureBarrier { - usage: hal::AccelerationStructureUses::BUILD_INPUT - ..hal::AccelerationStructureUses::BUILD_OUTPUT, - }, - ); - - cmd_buf_raw - .build_acceleration_structures(blas_storage.len() as u32, blas_descriptors); - } - - if blas_present && tlas_present { - cmd_buf_raw.transition_buffers(iter::once(scratch_buffer_barrier)); - } - let mut source_usage = hal::AccelerationStructureUses::empty(); - let mut destination_usage = hal::AccelerationStructureUses::empty(); - if blas_present { - source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT; - destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT - } - if tlas_present { - source_usage |= hal::AccelerationStructureUses::SHADER_INPUT; - destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT; - } - - cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier { - usage: source_usage..destination_usage, - }); + build_blas( + cmd_buf_raw, + blas_present, + tlas_present, + input_barriers, + blas_storage.len() as u32, + blas_descriptors, + scratch_buffer_barrier, + ); - if tlas_present { + if tlas_present { + unsafe { cmd_buf_raw .build_acceleration_structures(tlas_storage.len() as u32, tlas_descriptors); @@ -666,6 +332,7 @@ impl Global { ); } } + let scratch_mapping = unsafe { device .raw() @@ -836,326 +503,30 @@ impl Global { )>::new(); let mut scratch_buffer_blas_size = 0; - let mut blas_storage = Vec::<(&Blas, hal::AccelerationStructureEntries, u64)>::new(); + let mut blas_storage = + Vec::<(Arc>, hal::AccelerationStructureEntries, u64)>::new(); let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); - for entry in blas_iter { - let blas = blas_guard - .get(entry.blas_id) - .map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?; - cmd_buf_data.trackers.blas_s.insert_single(blas.clone()); - - if blas.raw.is_none() { - return Err(BuildAccelerationStructureError::InvalidBlas( - blas.error_ident(), - )); - } - cmd_buf_data.blas_actions.push(BlasAction { - blas: blas.clone(), - kind: crate::ray_tracing::BlasActionKind::Build(build_command_index), - }); + iter_blas( + blas_iter, + cmd_buf_data, + build_command_index, + &buffer_guard, + &blas_guard, + &mut buf_storage, + )?; - match entry.geometries { - BlasGeometries::TriangleGeometries(triangle_geometries) => { - for (i, mesh) in triangle_geometries.enumerate() { - let size_desc = match &blas.sizes { - wgt::BlasGeometrySizeDescriptors::Triangles { desc } => desc, - }; - if i >= size_desc.len() { - return Err( - BuildAccelerationStructureError::IncompatibleBlasBuildSizes( - blas.error_ident(), - ), - ); - } - let size_desc = &size_desc[i]; - - if size_desc.flags != mesh.size.flags - || size_desc.vertex_count < mesh.size.vertex_count - || size_desc.vertex_format != mesh.size.vertex_format - || size_desc.index_count.is_none() != mesh.size.index_count.is_none() - || (size_desc.index_count.is_none() - || size_desc.index_count.unwrap() < mesh.size.index_count.unwrap()) - || size_desc.index_format.is_none() != mesh.size.index_format.is_none() - || (size_desc.index_format.is_none() - || size_desc.index_format.unwrap() - != mesh.size.index_format.unwrap()) - { - return Err( - BuildAccelerationStructureError::IncompatibleBlasBuildSizes( - blas.error_ident(), - ), - ); - } - - if size_desc.index_count.is_some() && mesh.index_buffer.is_none() { - return Err(BuildAccelerationStructureError::MissingIndexBuffer( - blas.error_ident(), - )); - } - let vertex_buffer = match buffer_guard.get(mesh.vertex_buffer) { - Ok(buffer) => buffer, - Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId), - }; - let vertex_pending = cmd_buf_data.trackers.buffers.set_single( - vertex_buffer, - BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, - ); - let index_data = if let Some(index_id) = mesh.index_buffer { - let index_buffer = match buffer_guard.get(index_id) { - Ok(buffer) => buffer, - Err(_) => { - return Err(BuildAccelerationStructureError::InvalidBufferId) - } - }; - if mesh.index_buffer_offset.is_none() - || mesh.size.index_count.is_none() - || mesh.size.index_count.is_none() - { - return Err( - BuildAccelerationStructureError::MissingAssociatedData( - index_buffer.error_ident(), - ), - ); - } - let data = cmd_buf_data.trackers.buffers.set_single( - index_buffer, - hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, - ); - Some((index_buffer.clone(), data)) - } else { - None - }; - let transform_data = if let Some(transform_id) = mesh.transform_buffer { - let transform_buffer = match buffer_guard.get(transform_id) { - Ok(buffer) => buffer, - Err(_) => { - return Err(BuildAccelerationStructureError::InvalidBufferId) - } - }; - if mesh.transform_buffer_offset.is_none() { - return Err( - BuildAccelerationStructureError::MissingAssociatedData( - transform_buffer.error_ident(), - ), - ); - } - let data = cmd_buf_data.trackers.buffers.set_single( - transform_buffer, - BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, - ); - Some((transform_buffer.clone(), data)) - } else { - None - }; - - buf_storage.push(( - vertex_buffer.clone(), - vertex_pending, - index_data, - transform_data, - mesh, - None, - )) - } - - if let Some(last) = buf_storage.last_mut() { - last.5 = Some(blas.clone()); - } - } - } - } - - let mut triangle_entries = Vec::>::new(); let snatch_guard = device.snatchable_lock.read(); - for buf in &mut buf_storage { - let mesh = &buf.4; - let vertex_buffer = { - let vertex_buffer = buf.0.as_ref(); - let vertex_raw = vertex_buffer.raw.get(&snatch_guard).ok_or( - BuildAccelerationStructureError::InvalidBuffer(vertex_buffer.error_ident()), - )?; - if !vertex_buffer.usage.contains(BufferUsages::BLAS_INPUT) { - return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( - vertex_buffer.error_ident(), - )); - } - if let Some(barrier) = buf - .1 - .take() - .map(|pending| pending.into_hal(vertex_buffer, &snatch_guard)) - { - input_barriers.push(barrier); - } - if vertex_buffer.size - < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride - { - return Err(BuildAccelerationStructureError::InsufficientBufferSize( - vertex_buffer.error_ident(), - vertex_buffer.size, - (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride, - )); - } - let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride; - cmd_buf_data.buffer_memory_init_actions.extend( - vertex_buffer.initialization_status.read().create_action( - buffer_guard.get(mesh.vertex_buffer).unwrap(), - vertex_buffer_offset - ..(vertex_buffer_offset - + mesh.size.vertex_count as u64 * mesh.vertex_stride), - MemoryInitKind::NeedsInitializedMemory, - ), - ); - vertex_raw - }; - let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) = buf.2 { - let index_raw = index_buffer.raw.get(&snatch_guard).ok_or( - BuildAccelerationStructureError::InvalidBuffer(index_buffer.error_ident()), - )?; - if !index_buffer.usage.contains(BufferUsages::BLAS_INPUT) { - return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( - index_buffer.error_ident(), - )); - } - if let Some(barrier) = index_pending - .take() - .map(|pending| pending.into_hal(index_buffer, &snatch_guard)) - { - input_barriers.push(barrier); - } - let index_stride = match mesh.size.index_format.unwrap() { - wgt::IndexFormat::Uint16 => 2, - wgt::IndexFormat::Uint32 => 4, - }; - if mesh.index_buffer_offset.unwrap() % index_stride != 0 { - return Err(BuildAccelerationStructureError::UnalignedIndexBufferOffset( - index_buffer.error_ident(), - )); - } - let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride; - - if mesh.size.index_count.unwrap() % 3 != 0 { - return Err(BuildAccelerationStructureError::InvalidIndexCount( - index_buffer.error_ident(), - mesh.size.index_count.unwrap(), - )); - } - if index_buffer.size - < mesh.size.index_count.unwrap() as u64 * index_stride - + mesh.index_buffer_offset.unwrap() - { - return Err(BuildAccelerationStructureError::InsufficientBufferSize( - index_buffer.error_ident(), - index_buffer.size, - mesh.size.index_count.unwrap() as u64 * index_stride - + mesh.index_buffer_offset.unwrap(), - )); - } - - cmd_buf_data.buffer_memory_init_actions.extend( - index_buffer.initialization_status.read().create_action( - index_buffer, - mesh.index_buffer_offset.unwrap() - ..(mesh.index_buffer_offset.unwrap() + index_buffer_size), - MemoryInitKind::NeedsInitializedMemory, - ), - ); - Some(index_raw) - } else { - None - }; - let transform_buffer = - if let Some((ref mut transform_buffer, ref mut transform_pending)) = buf.3 { - let transform_id = mesh.transform_buffer.as_ref().unwrap(); - if mesh.transform_buffer_offset.is_none() { - return Err(BuildAccelerationStructureError::MissingAssociatedData( - transform_buffer.error_ident(), - )); - } - let transform_raw = transform_buffer.raw.get(&snatch_guard).ok_or( - BuildAccelerationStructureError::InvalidBuffer( - transform_buffer.error_ident(), - ), - )?; - if !transform_buffer.usage.contains(BufferUsages::BLAS_INPUT) { - return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( - transform_buffer.error_ident(), - )); - } - if let Some(barrier) = transform_pending - .take() - .map(|pending| pending.into_hal(transform_buffer, &snatch_guard)) - { - input_barriers.push(barrier); - } - if mesh.transform_buffer_offset.unwrap() % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 - { - return Err( - BuildAccelerationStructureError::UnalignedTransformBufferOffset( - transform_buffer.error_ident(), - ), - ); - } - if transform_buffer.size < 48 + mesh.transform_buffer_offset.unwrap() { - return Err(BuildAccelerationStructureError::InsufficientBufferSize( - transform_buffer.error_ident(), - transform_buffer.size, - 48 + mesh.transform_buffer_offset.unwrap(), - )); - } - cmd_buf_data.buffer_memory_init_actions.extend( - transform_buffer.initialization_status.read().create_action( - buffer_guard.get(*transform_id).unwrap(), - mesh.transform_buffer_offset.unwrap() - ..(mesh.index_buffer_offset.unwrap() + 48), - MemoryInitKind::NeedsInitializedMemory, - ), - ); - Some(transform_raw) - } else { - None - }; - - let triangles = hal::AccelerationStructureTriangles { - vertex_buffer: Some(vertex_buffer), - vertex_format: mesh.size.vertex_format, - first_vertex: mesh.first_vertex, - vertex_count: mesh.size.vertex_count, - vertex_stride: mesh.vertex_stride, - indices: index_buffer.map(|index_buffer| { - hal::AccelerationStructureTriangleIndices:: { - format: mesh.size.index_format.unwrap(), - buffer: Some(index_buffer), - offset: mesh.index_buffer_offset.unwrap() as u32, - count: mesh.size.index_count.unwrap(), - } - }), - transform: transform_buffer.map(|transform_buffer| { - hal::AccelerationStructureTriangleTransform { - buffer: transform_buffer, - offset: mesh.transform_buffer_offset.unwrap() as u32, - } - }), - flags: mesh.size.flags, - }; - triangle_entries.push(triangles); - if let Some(blas) = buf.5.as_ref() { - let scratch_buffer_offset = scratch_buffer_blas_size; - scratch_buffer_blas_size += align_to( - blas.size_info.build_scratch_size as u32, - SCRATCH_BUFFER_ALIGNMENT, - ) as u64; - - blas_storage.push(( - blas, - hal::AccelerationStructureEntries::Triangles(triangle_entries), - scratch_buffer_offset, - )); - triangle_entries = Vec::new(); - } - } - + iter_buffers( + &mut buf_storage, + &snatch_guard, + &mut input_barriers, + cmd_buf_data, + &buffer_guard, + &mut scratch_buffer_blas_size, + &mut blas_storage, + )?; let mut tlas_lock_store = Vec::<( RwLockReadGuard>, Option, @@ -1258,17 +629,6 @@ impl Global { return Ok(()); } - let scratch_buffer = unsafe { - device - .raw() - .create_buffer(&hal::BufferDescriptor { - label: Some("(wgpu) scratch buffer"), - size: max(scratch_buffer_blas_size, scratch_buffer_tlas_size), - usage: hal::BufferUses::ACCELERATION_STRUCTURE_SCRATCH | BufferUses::MAP_WRITE, - memory_flags: hal::MemoryFlags::empty(), - }) - .map_err(crate::device::DeviceError::from)? - }; let staging_buffer = if !instance_buffer_staging_source.is_empty() { unsafe { let staging_buffer = device @@ -1308,23 +668,27 @@ impl Global { None }; - let blas_descriptors = - blas_storage - .iter() - .map(|&(blas, ref entries, ref scratch_buffer_offset)| { - if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate { - log::info!("only rebuild implemented") - } - hal::BuildAccelerationStructureDescriptor { - entries, - mode: hal::AccelerationStructureBuildMode::Build, - flags: blas.flags, - source_acceleration_structure: None, - destination_acceleration_structure: blas.raw.as_ref().unwrap(), - scratch_buffer: &scratch_buffer, - scratch_buffer_offset: *scratch_buffer_offset, - } - }); + let scratch_buffer = unsafe { + device + .raw() + .create_buffer(&hal::BufferDescriptor { + label: Some("(wgpu) scratch buffer"), + size: max(scratch_buffer_blas_size, scratch_buffer_tlas_size), + usage: hal::BufferUses::ACCELERATION_STRUCTURE_SCRATCH | BufferUses::MAP_WRITE, + memory_flags: hal::MemoryFlags::empty(), + }) + .map_err(crate::device::DeviceError::from)? + }; + + let scratch_buffer_barrier = hal::BufferBarrier:: { + buffer: &scratch_buffer, + usage: BufferUses::ACCELERATION_STRUCTURE_SCRATCH + ..BufferUses::ACCELERATION_STRUCTURE_SCRATCH, + }; + + let blas_descriptors = blas_storage + .iter() + .map(|storage| map_blas(storage, &scratch_buffer)); let tlas_descriptors = tlas_storage.iter().map( |&(tlas, ref entries, ref scratch_buffer_offset, ref _range)| { @@ -1343,12 +707,6 @@ impl Global { }, ); - let scratch_buffer_barrier = hal::BufferBarrier:: { - buffer: &scratch_buffer, - usage: BufferUses::ACCELERATION_STRUCTURE_SCRATCH - ..BufferUses::ACCELERATION_STRUCTURE_SCRATCH, - }; - let mut lock_vec = Vec::::Buffer>>>>::new(); for tlas in &tlas_storage { @@ -1372,45 +730,15 @@ impl Global { let cmd_buf_raw = cmd_buf_data.encoder.open()?; - unsafe { - cmd_buf_raw.transition_buffers(input_barriers.into_iter()); - } - - if blas_present { - unsafe { - cmd_buf_raw.place_acceleration_structure_barrier( - hal::AccelerationStructureBarrier { - usage: hal::AccelerationStructureUses::BUILD_INPUT - ..hal::AccelerationStructureUses::BUILD_OUTPUT, - }, - ); - - cmd_buf_raw - .build_acceleration_structures(blas_storage.len() as u32, blas_descriptors); - } - } - - if blas_present && tlas_present { - unsafe { - cmd_buf_raw.transition_buffers(iter::once(scratch_buffer_barrier)); - } - } - - let mut source_usage = hal::AccelerationStructureUses::empty(); - let mut destination_usage = hal::AccelerationStructureUses::empty(); - if blas_present { - source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT; - destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT - } - if tlas_present { - source_usage |= hal::AccelerationStructureUses::SHADER_INPUT; - destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT; - } - unsafe { - cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier { - usage: source_usage..destination_usage, - }); - } + build_blas( + cmd_buf_raw, + blas_present, + tlas_present, + input_barriers, + blas_storage.len() as u32, + blas_descriptors, + scratch_buffer_barrier, + ); if tlas_present { unsafe { @@ -1568,3 +896,390 @@ impl BakedCommands { Ok(()) } } + +///iterates over the blas iterator, and it's geometry, pushing the buffers into a storage vector (and also some validation). +fn iter_blas<'a, A: HalApi>( + blas_iter: impl Iterator>, + cmd_buf_data: &mut CommandBufferMutable, + build_command_index: NonZeroU64, + buffer_guard: &RwLockReadGuard>>, + blas_guard: &RwLockReadGuard>>, + buf_storage: &mut BufferStorage<'a, A>, +) -> Result<(), BuildAccelerationStructureError> { + for entry in blas_iter { + let blas = blas_guard + .get(entry.blas_id) + .map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?; + cmd_buf_data.trackers.blas_s.insert_single(blas.clone()); + + if blas.raw.is_none() { + return Err(BuildAccelerationStructureError::InvalidBlas( + blas.error_ident(), + )); + } + + cmd_buf_data.blas_actions.push(BlasAction { + blas: blas.clone(), + kind: crate::ray_tracing::BlasActionKind::Build(build_command_index), + }); + + match entry.geometries { + BlasGeometries::TriangleGeometries(triangle_geometries) => { + for (i, mesh) in triangle_geometries.enumerate() { + let size_desc = match &blas.sizes { + wgt::BlasGeometrySizeDescriptors::Triangles { desc } => desc, + }; + if i >= size_desc.len() { + return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes( + blas.error_ident(), + )); + } + let size_desc = &size_desc[i]; + + if size_desc.flags != mesh.size.flags + || size_desc.vertex_count < mesh.size.vertex_count + || size_desc.vertex_format != mesh.size.vertex_format + || size_desc.index_count.is_none() != mesh.size.index_count.is_none() + || (size_desc.index_count.is_none() + || size_desc.index_count.unwrap() < mesh.size.index_count.unwrap()) + || size_desc.index_format.is_none() != mesh.size.index_format.is_none() + || (size_desc.index_format.is_none() + || size_desc.index_format.unwrap() != mesh.size.index_format.unwrap()) + { + return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes( + blas.error_ident(), + )); + } + + if size_desc.index_count.is_some() && mesh.index_buffer.is_none() { + return Err(BuildAccelerationStructureError::MissingIndexBuffer( + blas.error_ident(), + )); + } + let vertex_buffer = match buffer_guard.get(mesh.vertex_buffer) { + Ok(buffer) => buffer, + Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId), + }; + let vertex_pending = cmd_buf_data.trackers.buffers.set_single( + vertex_buffer, + BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, + ); + let index_data = if let Some(index_id) = mesh.index_buffer { + let index_buffer = match buffer_guard.get(index_id) { + Ok(buffer) => buffer, + Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId), + }; + if mesh.index_buffer_offset.is_none() + || mesh.size.index_count.is_none() + || mesh.size.index_count.is_none() + { + return Err(BuildAccelerationStructureError::MissingAssociatedData( + index_buffer.error_ident(), + )); + } + let data = cmd_buf_data.trackers.buffers.set_single( + index_buffer, + hal::BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, + ); + Some((index_buffer.clone(), data)) + } else { + None + }; + let transform_data = if let Some(transform_id) = mesh.transform_buffer { + let transform_buffer = match buffer_guard.get(transform_id) { + Ok(buffer) => buffer, + Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId), + }; + if mesh.transform_buffer_offset.is_none() { + return Err(BuildAccelerationStructureError::MissingAssociatedData( + transform_buffer.error_ident(), + )); + } + let data = cmd_buf_data.trackers.buffers.set_single( + transform_buffer, + BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT, + ); + Some((transform_buffer.clone(), data)) + } else { + None + }; + buf_storage.push(( + vertex_buffer.clone(), + vertex_pending, + index_data, + transform_data, + mesh, + None, + )); + } + + if let Some(last) = buf_storage.last_mut() { + last.5 = Some(blas.clone()); + } + } + } + } + Ok(()) +} + +/// Iterates over the buffers generated [iter_blas] and convert the barriers into hal barriers, and the triangles into hal [AccelerationStructureEntries] (and also some validation). +fn iter_buffers<'a, 'b, A: HalApi>( + buf_storage: &'a mut BufferStorage<'b, A>, + snatch_guard: &'a SnatchGuard, + input_barriers: &mut Vec>, + cmd_buf_data: &mut CommandBufferMutable, + buffer_guard: &RwLockReadGuard>>, + scratch_buffer_blas_size: &mut u64, + blas_storage: &mut Vec<(Arc>, hal::AccelerationStructureEntries<'a, A>, u64)>, +) -> Result<(), BuildAccelerationStructureError> { + let mut triangle_entries = Vec::>::new(); + for buf in buf_storage { + let mesh = &buf.4; + let vertex_buffer = { + let vertex_buffer = buf.0.as_ref(); + let vertex_raw = vertex_buffer.raw.get(snatch_guard).ok_or( + BuildAccelerationStructureError::InvalidBuffer(vertex_buffer.error_ident()), + )?; + if !vertex_buffer.usage.contains(BufferUsages::BLAS_INPUT) { + return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( + vertex_buffer.error_ident(), + )); + } + if let Some(barrier) = buf + .1 + .take() + .map(|pending| pending.into_hal(vertex_buffer, snatch_guard)) + { + input_barriers.push(barrier); + } + if vertex_buffer.size + < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride + { + return Err(BuildAccelerationStructureError::InsufficientBufferSize( + vertex_buffer.error_ident(), + vertex_buffer.size, + (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride, + )); + } + let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride; + cmd_buf_data.buffer_memory_init_actions.extend( + vertex_buffer.initialization_status.read().create_action( + buffer_guard.get(mesh.vertex_buffer).unwrap(), + vertex_buffer_offset + ..(vertex_buffer_offset + + mesh.size.vertex_count as u64 * mesh.vertex_stride), + MemoryInitKind::NeedsInitializedMemory, + ), + ); + vertex_raw + }; + let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) = buf.2 { + let index_raw = index_buffer.raw.get(snatch_guard).ok_or( + BuildAccelerationStructureError::InvalidBuffer(index_buffer.error_ident()), + )?; + if !index_buffer.usage.contains(BufferUsages::BLAS_INPUT) { + return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( + index_buffer.error_ident(), + )); + } + if let Some(barrier) = index_pending + .take() + .map(|pending| pending.into_hal(index_buffer, snatch_guard)) + { + input_barriers.push(barrier); + } + let index_stride = match mesh.size.index_format.unwrap() { + wgt::IndexFormat::Uint16 => 2, + wgt::IndexFormat::Uint32 => 4, + }; + if mesh.index_buffer_offset.unwrap() % index_stride != 0 { + return Err(BuildAccelerationStructureError::UnalignedIndexBufferOffset( + index_buffer.error_ident(), + )); + } + let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride; + + if mesh.size.index_count.unwrap() % 3 != 0 { + return Err(BuildAccelerationStructureError::InvalidIndexCount( + index_buffer.error_ident(), + mesh.size.index_count.unwrap(), + )); + } + if index_buffer.size + < mesh.size.index_count.unwrap() as u64 * index_stride + + mesh.index_buffer_offset.unwrap() + { + return Err(BuildAccelerationStructureError::InsufficientBufferSize( + index_buffer.error_ident(), + index_buffer.size, + mesh.size.index_count.unwrap() as u64 * index_stride + + mesh.index_buffer_offset.unwrap(), + )); + } + + cmd_buf_data.buffer_memory_init_actions.extend( + index_buffer.initialization_status.read().create_action( + index_buffer, + mesh.index_buffer_offset.unwrap() + ..(mesh.index_buffer_offset.unwrap() + index_buffer_size), + MemoryInitKind::NeedsInitializedMemory, + ), + ); + Some(index_raw) + } else { + None + }; + let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) = + buf.3 + { + if mesh.transform_buffer_offset.is_none() { + return Err(BuildAccelerationStructureError::MissingAssociatedData( + transform_buffer.error_ident(), + )); + } + let transform_raw = transform_buffer.raw.get(snatch_guard).ok_or( + BuildAccelerationStructureError::InvalidBuffer(transform_buffer.error_ident()), + )?; + if !transform_buffer.usage.contains(BufferUsages::BLAS_INPUT) { + return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag( + transform_buffer.error_ident(), + )); + } + if let Some(barrier) = transform_pending + .take() + .map(|pending| pending.into_hal(transform_buffer, snatch_guard)) + { + input_barriers.push(barrier); + } + if mesh.transform_buffer_offset.unwrap() % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 { + return Err( + BuildAccelerationStructureError::UnalignedTransformBufferOffset( + transform_buffer.error_ident(), + ), + ); + } + if transform_buffer.size < 48 + mesh.transform_buffer_offset.unwrap() { + return Err(BuildAccelerationStructureError::InsufficientBufferSize( + transform_buffer.error_ident(), + transform_buffer.size, + 48 + mesh.transform_buffer_offset.unwrap(), + )); + } + cmd_buf_data.buffer_memory_init_actions.extend( + transform_buffer.initialization_status.read().create_action( + transform_buffer, + mesh.transform_buffer_offset.unwrap()..(mesh.index_buffer_offset.unwrap() + 48), + MemoryInitKind::NeedsInitializedMemory, + ), + ); + Some(transform_raw) + } else { + None + }; + + let triangles = hal::AccelerationStructureTriangles { + vertex_buffer: Some(vertex_buffer), + vertex_format: mesh.size.vertex_format, + first_vertex: mesh.first_vertex, + vertex_count: mesh.size.vertex_count, + vertex_stride: mesh.vertex_stride, + indices: index_buffer.map(|index_buffer| hal::AccelerationStructureTriangleIndices::< + A, + > { + format: mesh.size.index_format.unwrap(), + buffer: Some(index_buffer), + offset: mesh.index_buffer_offset.unwrap() as u32, + count: mesh.size.index_count.unwrap(), + }), + transform: transform_buffer.map(|transform_buffer| { + hal::AccelerationStructureTriangleTransform { + buffer: transform_buffer, + offset: mesh.transform_buffer_offset.unwrap() as u32, + } + }), + flags: mesh.size.flags, + }; + triangle_entries.push(triangles); + if let Some(blas) = buf.5.take() { + let scratch_buffer_offset = *scratch_buffer_blas_size; + *scratch_buffer_blas_size += align_to( + blas.size_info.build_scratch_size as u32, + SCRATCH_BUFFER_ALIGNMENT, + ) as u64; + + blas_storage.push(( + blas, + hal::AccelerationStructureEntries::Triangles(triangle_entries), + scratch_buffer_offset, + )); + triangle_entries = Vec::new(); + } + } + Ok(()) +} + +fn map_blas<'a, A: HalApi>( + storage: &'a (Arc>, hal::AccelerationStructureEntries, BufferAddress), + scratch_buffer: &'a ::Buffer, +) -> hal::BuildAccelerationStructureDescriptor<'a, A> { + let (blas, entries, scratch_buffer_offset) = storage; + if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate { + log::info!("only rebuild implemented") + } + hal::BuildAccelerationStructureDescriptor { + entries, + mode: hal::AccelerationStructureBuildMode::Build, + flags: blas.flags, + source_acceleration_structure: None, + destination_acceleration_structure: blas.raw.as_ref().unwrap(), + scratch_buffer, + scratch_buffer_offset: *scratch_buffer_offset, + } +} + +fn build_blas<'a, A: HalApi>( + cmd_buf_raw: &mut A::CommandEncoder, + blas_present: bool, + tlas_present: bool, + input_barriers: Vec>, + desc_len: u32, + blas_descriptors: impl Iterator>, + scratch_buffer_barrier: hal::BufferBarrier, +) { + unsafe { + cmd_buf_raw.transition_buffers(input_barriers.into_iter()); + } + + if blas_present { + unsafe { + cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier { + usage: hal::AccelerationStructureUses::BUILD_INPUT + ..hal::AccelerationStructureUses::BUILD_OUTPUT, + }); + + cmd_buf_raw.build_acceleration_structures(desc_len, blas_descriptors); + } + } + + if blas_present && tlas_present { + unsafe { + cmd_buf_raw.transition_buffers(iter::once(scratch_buffer_barrier)); + } + } + + let mut source_usage = hal::AccelerationStructureUses::empty(); + let mut destination_usage = hal::AccelerationStructureUses::empty(); + if blas_present { + source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT; + destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT + } + if tlas_present { + source_usage |= hal::AccelerationStructureUses::SHADER_INPUT; + destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT; + } + unsafe { + cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier { + usage: source_usage..destination_usage, + }); + } +}