Skip to content

Commit

Permalink
More acceleration structure operations (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
attackgoat authored Nov 29, 2024
1 parent bdc4a81 commit df3b6b8
Show file tree
Hide file tree
Showing 15 changed files with 1,276 additions and 641 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ash-window = "0.13"
derive_builder = "0.20"
gpu-allocator = "0.27"
log = "0.4"
ordered-float = "4.1"
ordered-float = "4.2"
parking_lot = { version = "0.12", optional = true }
paste = "1.0"
profiling = "1.0"
Expand All @@ -39,10 +39,10 @@ ash-molten = "0.19"
[dev-dependencies]
anyhow = "1.0"
bmfont = { version = "0.3", default-features = false }
bytemuck = "1.14"
bytemuck = "1.16"
clap = { version = "4.5", features = ["derive"] }
glam = { version = "0.27", features = ["bytemuck"] }
half = { version = "2.3", features = ["bytemuck"] }
glam = { version = "0.28", features = ["bytemuck"] }
half = { version = "2.4", features = ["bytemuck"] }
hassle-rs = "0.11"
image = "0.25"
inline-spirv = "0.2"
Expand Down
8 changes: 2 additions & 6 deletions contrib/screen-13-fx/src/image_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ use {
#[cfg(debug_assertions)]
use log::warn;

/// Rounds `val` up to the nearest multiple of `atom`.
///
/// Returns `val` unchanged when it is already a multiple of `atom`.
///
/// Unlike the bit-mask form `(val + atom - 1) & !(atom - 1)`, this is correct for *any*
/// non-zero `atom`, including values that are not a power of two (such as `12`). The mask
/// trick silently produces wrong results for those, sometimes smaller than `val` itself.
///
/// # Panics
///
/// Panics if `atom` is zero. When overflow checks are enabled, also panics if the rounded
/// result does not fit in a `u32`.
fn align_up_u32(val: u32, atom: u32) -> u32 {
    val.next_multiple_of(atom)
}

/// Describes the channels and pixel stride of an image format
#[derive(Clone, Copy, Debug)]
pub enum ImageFormat {
Expand Down Expand Up @@ -133,7 +129,7 @@ impl ImageLoader {
);

#[cfg(debug_assertions)]
if pixels.len() > align_up_u32(format.stride() as u32 * width * height, 4) as usize {
if pixels.len() > (format.stride() as u32 * width * height).next_multiple_of(4) as usize {
warn!("unused data");
}

Expand All @@ -156,7 +152,7 @@ impl ImageLoader {

//trace!("{bitmap_width}x{bitmap_height} Stride={bitmap_stride}");

let pixel_buf_stride = align_up_u32(stride, 12);
let pixel_buf_stride = stride.next_multiple_of(12);
let pixel_buf_len = (pixel_buf_stride * height) as vk::DeviceSize;

//trace!("pixel_buf_len={pixel_buf_len} pixel_buf_stride={pixel_buf_stride}");
Expand Down
90 changes: 35 additions & 55 deletions examples/ray_omni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,30 +148,31 @@ fn create_blas(
device: &Arc<Device>,
models: &[&Model],
) -> Result<Arc<AccelerationStructure>, DriverError> {
let info = AccelerationStructureGeometryInfo {
ty: vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL,
flags: vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE,
geometries: models
let info = AccelerationStructureGeometryInfo::blas(
models
.iter()
.map(|model| AccelerationStructureGeometry {
max_primitive_count: model.index_count / 3,
flags: vk::GeometryFlagsKHR::OPAQUE,
geometry: AccelerationStructureGeometryData::Triangles {
index_data: DeviceOrHostAddress::DeviceAddress(Buffer::device_address(
&model.index_buf,
)),
index_type: vk::IndexType::UINT32,
max_vertex: model.vertex_count,
transform_data: None,
vertex_data: DeviceOrHostAddress::DeviceAddress(Buffer::device_address(
&model.vertex_buf,
)),
vertex_format: vk::Format::R32G32B32_SFLOAT,
vertex_stride: 24,
},
.map(|model| {
(
AccelerationStructureGeometry {
max_primitive_count: model.index_count / 3,
flags: vk::GeometryFlagsKHR::OPAQUE,
geometry: AccelerationStructureGeometryData::triangles(
Buffer::device_address(&model.index_buf),
vk::IndexType::UINT32,
model.vertex_count,
None,
Buffer::device_address(&model.vertex_buf),
vk::Format::R32G32B32_SFLOAT,
24,
),
},
vk::AccelerationStructureBuildRangeInfoKHR::default()
.primitive_count(model.index_count / 3),
)
})
.collect(),
};
.collect::<Box<_>>(),
)
.flags(vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE);
let size = AccelerationStructure::size_of(device, &info);

let mut render_graph = RenderGraph::new();
Expand All @@ -196,15 +197,7 @@ fn create_blas(
.to_builder()
.alignment(accel_struct_scratch_offset_alignment),
)?);
let build_ranges = models
.iter()
.map(|model| vk::AccelerationStructureBuildRangeInfoKHR {
primitive_count: model.index_count / 3,
primitive_offset: 0,
first_vertex: 0,
transform_offset: 0,
})
.collect::<Box<_>>();
let scratch_data = render_graph.node_device_address(scratch_buf);

let mut pass = render_graph.begin_pass("Build BLAS");

Expand All @@ -219,7 +212,7 @@ fn create_blas(
pass.access_node(blas, AccessType::AccelerationStructureBuildWrite)
.access_node(scratch_buf, AccessType::AccelerationStructureBufferWrite)
.record_acceleration(move |accel, _| {
accel.build_structure(blas, scratch_buf, &info, &build_ranges);
accel.build_structure(&info, blas, scratch_data);
});

let blas = render_graph.unbind_node(blas);
Expand Down Expand Up @@ -358,18 +351,14 @@ fn create_tlas(
buffer
});

let info = AccelerationStructureGeometryInfo {
ty: vk::AccelerationStructureTypeKHR::TOP_LEVEL,
flags: vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE,
geometries: vec![AccelerationStructureGeometry {
max_primitive_count: 2,
flags: vk::GeometryFlagsKHR::OPAQUE,
geometry: AccelerationStructureGeometryData::Instances {
array_of_pointers: false,
data: DeviceOrHostAddress::DeviceAddress(Buffer::device_address(&instance_buf)),
},
}],
};
let info = AccelerationStructureGeometryInfo::tlas([(
AccelerationStructureGeometry::opaque(
2,
AccelerationStructureGeometryData::instances(Buffer::device_address(&instance_buf)),
),
vk::AccelerationStructureBuildRangeInfoKHR::default().primitive_count(1),
)])
.flags(vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE);
let size = AccelerationStructure::size_of(device, &info);
let tlas =
render_graph.bind_node(pool.lease(AccelerationStructureInfo::tlas(size.create_size))?);
Expand All @@ -391,6 +380,7 @@ fn create_tlas(
.alignment(accel_struct_scratch_offset_alignment),
)?,
);
let scratch_data = render_graph.node_device_address(scratch_buf);
let blas = render_graph.bind_node(blas);
let instance_buf = render_graph.bind_node(instance_buf);

Expand All @@ -401,17 +391,7 @@ fn create_tlas(
.access_node(scratch_buf, AccessType::AccelerationStructureBufferWrite)
.access_node(tlas, AccessType::AccelerationStructureBuildWrite)
.record_acceleration(move |accel, _| {
accel.build_structure(
tlas,
scratch_buf,
&info,
&[vk::AccelerationStructureBuildRangeInfoKHR {
first_vertex: 0,
primitive_count: 1,
primitive_offset: 0,
transform_offset: 0,
}],
)
accel.build_structure(&info, tlas, scratch_data);
});

Ok(tlas)
Expand Down
129 changes: 52 additions & 77 deletions examples/ray_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,6 @@ static SHADER_SHADOW_MISS: &[u32] = inline_spirv!(
)
.as_slice();

/// Rounds `val` up to the nearest multiple of `atom`.
///
/// Returns `val` unchanged when it is already a multiple of `atom`.
///
/// The previous bit-mask form `(val + atom - 1) & !(atom - 1)` is only valid when `atom` is a
/// power of two; this version is correct for any non-zero `atom`.
///
/// # Panics
///
/// Panics if `atom` is zero. When overflow checks are enabled, also panics if the rounded
/// result does not fit in a `u32`.
fn align_up(val: u32, atom: u32) -> u32 {
    val.next_multiple_of(atom)
}

fn create_ray_trace_pipeline(device: &Arc<Device>) -> Result<Arc<RayTracePipeline>, DriverError> {
Ok(Arc::new(RayTracePipeline::create(
device,
Expand Down Expand Up @@ -514,15 +510,18 @@ fn main() -> anyhow::Result<()> {
// Setup a shader binding table
// ------------------------------------------------------------------------------------------ //

let sbt_handle_size = align_up(shader_group_handle_size, shader_group_handle_alignment);
let sbt_rgen_size = sbt_handle_size;
let sbt_hit_size = sbt_handle_size;
let sbt_miss_size = 2 * sbt_handle_size;
let sbt_rgen_size = shader_group_handle_size;
let sbt_hit_start = sbt_rgen_size.next_multiple_of(shader_group_base_alignment);
let sbt_hit_size = shader_group_handle_size;
let sbt_miss_start =
(sbt_hit_start + sbt_hit_size).next_multiple_of(shader_group_base_alignment);
let sbt_miss_size =
2 * shader_group_handle_size.next_multiple_of(shader_group_handle_alignment);
let sbt_buf = Arc::new({
let mut buf = Buffer::create(
&window.device,
BufferInfo::host_mem(
(sbt_rgen_size + sbt_hit_size + sbt_miss_size) as _,
(sbt_miss_start + sbt_miss_size) as _,
vk::BufferUsageFlags::SHADER_BINDING_TABLE_KHR
| vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS,
)
Expand All @@ -531,36 +530,39 @@ fn main() -> anyhow::Result<()> {
)
.unwrap();

let mut data = Buffer::mapped_slice_mut(&mut buf);
data.fill(0);

let data = Buffer::mapped_slice_mut(&mut buf);
let rgen_handle = RayTracePipeline::group_handle(&ray_trace_pipeline, 0)?;
data[0..rgen_handle.len()].copy_from_slice(rgen_handle);
data = &mut data[sbt_rgen_size as _..];

// If hit/miss had different strides we would need to iterate each here
for idx in 1..4 {
let handle = RayTracePipeline::group_handle(&ray_trace_pipeline, idx)?;
data[0..handle.len()].copy_from_slice(handle);
data = &mut data[sbt_handle_size as _..];
}
let hit_handle = RayTracePipeline::group_handle(&ray_trace_pipeline, 1)?;
data[sbt_hit_start as usize..sbt_hit_start as usize + hit_handle.len()]
.copy_from_slice(hit_handle);

let miss_handle = RayTracePipeline::group_handle(&ray_trace_pipeline, 2)?;
data[sbt_miss_start as usize..sbt_miss_start as usize + miss_handle.len()]
.copy_from_slice(miss_handle);
let miss_shadow_handle = RayTracePipeline::group_handle(&ray_trace_pipeline, 3)?;
let sbt_miss_shadow_start = sbt_miss_start + shader_group_handle_alignment;
data[sbt_miss_shadow_start as usize
..sbt_miss_shadow_start as usize + miss_shadow_handle.len()]
.copy_from_slice(miss_shadow_handle);

buf
});
let sbt_address = Buffer::device_address(&sbt_buf);
let sbt_rgen = vk::StridedDeviceAddressRegionKHR {
device_address: sbt_address,
stride: sbt_rgen_size as _,
stride: shader_group_handle_size as _,
size: sbt_rgen_size as _,
};
let sbt_hit = vk::StridedDeviceAddressRegionKHR {
device_address: sbt_rgen.device_address + sbt_rgen_size as vk::DeviceAddress,
stride: sbt_handle_size as _,
device_address: sbt_address + sbt_hit_start as vk::DeviceAddress,
stride: shader_group_handle_size as _,
size: sbt_hit_size as _,
};
let sbt_miss = vk::StridedDeviceAddressRegionKHR {
device_address: sbt_hit.device_address + sbt_hit_size as vk::DeviceAddress,
stride: sbt_handle_size as _,
device_address: sbt_address + sbt_miss_start as vk::DeviceAddress,
stride: shader_group_handle_size as _,
size: sbt_miss_size as _,
};
let sbt_callable = vk::StridedDeviceAddressRegionKHR::default();
Expand All @@ -576,25 +578,21 @@ fn main() -> anyhow::Result<()> {
// Create the bottom level acceleration structure
// ------------------------------------------------------------------------------------------ //

let blas_geometry_info = AccelerationStructureGeometryInfo {
ty: vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL,
flags: vk::BuildAccelerationStructureFlagsKHR::empty(),
geometries: vec![AccelerationStructureGeometry {
max_primitive_count: triangle_count,
flags: vk::GeometryFlagsKHR::OPAQUE,
geometry: AccelerationStructureGeometryData::Triangles {
index_data: DeviceOrHostAddress::DeviceAddress(Buffer::device_address(&index_buf)),
index_type: vk::IndexType::UINT32,
max_vertex: vertex_count,
transform_data: None,
vertex_data: DeviceOrHostAddress::DeviceAddress(Buffer::device_address(
&vertex_buf,
)),
vertex_format: vk::Format::R32G32B32_SFLOAT,
vertex_stride: 12,
},
}],
};
let blas_geometry_info = AccelerationStructureGeometryInfo::blas([(
AccelerationStructureGeometry::opaque(
triangle_count,
AccelerationStructureGeometryData::triangles(
Buffer::device_address(&index_buf),
vk::IndexType::UINT32,
vertex_count,
None,
Buffer::device_address(&vertex_buf),
vk::Format::R32G32B32_SFLOAT,
12,
),
),
vk::AccelerationStructureBuildRangeInfoKHR::default().primitive_count(triangle_count),
)]);
let blas_size = AccelerationStructure::size_of(&window.device, &blas_geometry_info);
let blas = Arc::new(AccelerationStructure::create(
&window.device,
Expand Down Expand Up @@ -642,18 +640,13 @@ fn main() -> anyhow::Result<()> {
// Create the top level acceleration structure
// ------------------------------------------------------------------------------------------ //

let tlas_geometry_info = AccelerationStructureGeometryInfo {
ty: vk::AccelerationStructureTypeKHR::TOP_LEVEL,
flags: vk::BuildAccelerationStructureFlagsKHR::empty(),
geometries: vec![AccelerationStructureGeometry {
max_primitive_count: 1,
flags: vk::GeometryFlagsKHR::OPAQUE,
geometry: AccelerationStructureGeometryData::Instances {
array_of_pointers: false,
data: DeviceOrHostAddress::DeviceAddress(Buffer::device_address(&instance_buf)),
},
}],
};
let tlas_geometry_info = AccelerationStructureGeometryInfo::tlas([(
AccelerationStructureGeometry::opaque(
1,
AccelerationStructureGeometryData::instances(Buffer::device_address(&instance_buf)),
),
vk::AccelerationStructureBuildRangeInfoKHR::default().primitive_count(1),
)]);
let tlas_size = AccelerationStructure::size_of(&window.device, &tlas_geometry_info);
let tlas = Arc::new(AccelerationStructure::create(
&window.device,
Expand Down Expand Up @@ -689,6 +682,7 @@ fn main() -> anyhow::Result<()> {
.to_builder()
.alignment(accel_struct_scratch_offset_alignment),
)?);
let scratch_data = render_graph.node_device_address(scratch_buf);

render_graph
.begin_pass("Build BLAS")
Expand All @@ -697,17 +691,7 @@ fn main() -> anyhow::Result<()> {
.access_node(scratch_buf, AccessType::AccelerationStructureBufferWrite)
.access_node(blas_node, AccessType::AccelerationStructureBuildWrite)
.record_acceleration(move |accel, _| {
accel.build_structure(
blas_node,
scratch_buf,
&blas_geometry_info,
&[vk::AccelerationStructureBuildRangeInfoKHR {
first_vertex: 0,
primitive_count: triangle_count,
primitive_offset: 0,
transform_offset: 0,
}],
)
accel.build_structure(&blas_geometry_info, blas_node, scratch_data);
});
}

Expand All @@ -722,6 +706,7 @@ fn main() -> anyhow::Result<()> {
.to_builder()
.alignment(accel_struct_scratch_offset_alignment),
)?);
let scratch_data = render_graph.node_device_address(scratch_buf);
let instance_node = render_graph.bind_node(&instance_buf);
let tlas_node = render_graph.bind_node(&tlas);

Expand All @@ -732,17 +717,7 @@ fn main() -> anyhow::Result<()> {
.access_node(scratch_buf, AccessType::AccelerationStructureBufferWrite)
.access_node(tlas_node, AccessType::AccelerationStructureBuildWrite)
.record_acceleration(move |accel, _| {
accel.build_structure(
tlas_node,
scratch_buf,
&tlas_geometry_info,
&[vk::AccelerationStructureBuildRangeInfoKHR {
first_vertex: 0,
primitive_count: 1,
primitive_offset: 0,
transform_offset: 0,
}],
);
accel.build_structure(&tlas_geometry_info, tlas_node, scratch_data);
});
}

Expand Down
Loading

0 comments on commit df3b6b8

Please sign in to comment.