Skip to content

Commit

Permalink
spv-out: implement OpArrayLength on array buffer bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark authored and teoxoy committed Apr 2, 2024
1 parent 1fd47b5 commit bfe0b90
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 134 deletions.
91 changes: 74 additions & 17 deletions naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ Bounds-checking for SPIR-V output.
*/

use super::{
helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator,
Instruction, Word,
helpers::{global_needs_wrapper, map_storage_class},
selection::Selection,
Block, BlockContext, Error, IdGenerator, Instruction, Word,
};
use crate::{arena::Handle, proc::BoundsCheckPolicy};

Expand Down Expand Up @@ -42,32 +43,88 @@ impl<'w> BlockContext<'w> {
array: Handle<crate::Expression>,
block: &mut Block,
) -> Result<Word, Error> {
// Naga IR permits runtime-sized arrays as global variables or as the
// final member of a struct that is a global variable. SPIR-V permits
// only the latter, so this back end wraps bare runtime-sized arrays
// in a made-up struct; see `helpers::global_needs_wrapper` and its uses.
// This code must handle both cases.
let (structure_id, last_member_index) = match self.ir_function.expressions[array] {
// Naga IR permits runtime-sized arrays as global variables, or as the
// final member of a struct that is a global variable, or one of these
// inside a buffer that is itself an element in a buffer bindings array.
// SPIR-V requires that runtime-sized arrays are wrapped in structs.
// See `helpers::global_needs_wrapper` and its uses.
let (opt_array_index, global_handle, opt_last_member_index) = match self
.ir_function
.expressions[array]
{
// Note that SPIR-V forbids `OpArrayLength` on a variable pointer,
// so we aren't handling `crate::Expression::Access` here.
crate::Expression::AccessIndex { base, index } => {
match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(handle) => (
self.writer.global_variables[handle.index()].access_id,
index,
),
_ => return Err(Error::Validation("array length expression")),
// The global variable is an array of buffer bindings of structs,
// and we are accessing the last member.
crate::Expression::AccessIndex {
base: base_outer,
index: index_outer,
} => match self.ir_function.expressions[base_outer] {
crate::Expression::GlobalVariable(handle) => {
(Some(index_outer), handle, Some(index))
}
_ => return Err(Error::Validation("array length expression case-1a")),
},
crate::Expression::GlobalVariable(handle) => {
let global = &self.ir_module.global_variables[handle];
match self.ir_module.types[global.ty].inner {
// The global variable is an array of buffer bindings of run-time arrays.
crate::TypeInner::BindingArray { .. } => (Some(index), handle, None),
// The global variable is a struct, and we are accessing the last member
_ => (None, handle, Some(index)),
}
}
_ => return Err(Error::Validation("array length expression case-1c")),
}
}
// The global variable is a run-time array.
crate::Expression::GlobalVariable(handle) => {
let global = &self.ir_module.global_variables[handle];
if !global_needs_wrapper(self.ir_module, global) {
return Err(Error::Validation("array length expression"));
return Err(Error::Validation("array length expression case-2"));
}

(self.writer.global_variables[handle.index()].var_id, 0)
(None, handle, None)
}
_ => return Err(Error::Validation("array length expression")),
_ => return Err(Error::Validation("array length expression case-3")),
};

let gvar = self.writer.global_variables[global_handle.index()].clone();
let global = &self.ir_module.global_variables[global_handle];
let (last_member_index, gvar_id) = match opt_last_member_index {
Some(index) => (index, gvar.access_id),
None => {
if !global_needs_wrapper(self.ir_module, global) {
return Err(Error::Validation(
"pointer to a global that is not a wrapped array",
));
}
(0, gvar.var_id)
}
};
let structure_id = match opt_array_index {
// We are indexing inside a binding array, generate the access op.
Some(index) => {
let element_type_id = match self.ir_module.types[global.ty].inner {
crate::TypeInner::BindingArray { base, size: _ } => {
let class = map_storage_class(global.space);
self.get_pointer_id(base, class)?
}
_ => return Err(Error::Validation("array length expression case-4")),
};
let index_id = self.get_index_constant(index);
let structure_id = self.gen_id();
block.body.push(Instruction::access_chain(
element_type_id,
structure_id,
gvar_id,
&[index_id],
));
structure_id
}
None => gvar_id,
};
let length_id = self.gen_id();
block.body.push(Instruction::array_length(
self.writer.get_uint_type_id(),
Expand Down
9 changes: 9 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,15 @@ impl BlockContext<'_> {
self.writer
.get_constant_scalar(crate::Literal::I32(scope as _))
}

fn get_pointer_id(
&mut self,
handle: Handle<crate::Type>,
class: spirv::StorageClass,
) -> Result<Word, Error> {
self.writer
.get_pointer_id(&self.ir_module.types, handle, class)
}
}

#[derive(Clone, Copy, Default)]
Expand Down
62 changes: 32 additions & 30 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,36 +565,38 @@ impl Writer {
// Handle globals are pre-emitted and should be loaded automatically.
//
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
let is_binding_array = match ir_module.types[var.ty].inner {
crate::TypeInner::BindingArray { .. } => true,
_ => false,
};

if var.space == crate::AddressSpace::Handle && !is_binding_array {
let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
let id = self.id_gen.next();
prelude
.body
.push(Instruction::load(var_type_id, id, gv.var_id, None));
gv.access_id = gv.var_id;
gv.handle_id = id;
} else if global_needs_wrapper(ir_module, var) {
let class = map_storage_class(var.space);
let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?;
let index_id = self.get_index_constant(0);

let id = self.id_gen.next();
prelude.body.push(Instruction::access_chain(
pointer_type_id,
id,
gv.var_id,
&[index_id],
));
gv.access_id = id;
} else {
// by default, the variable ID is accessed as is
gv.access_id = gv.var_id;
};
match ir_module.types[var.ty].inner {
crate::TypeInner::BindingArray { .. } => {
gv.access_id = gv.var_id;
}
_ => {
if var.space == crate::AddressSpace::Handle {
let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
let id = self.id_gen.next();
prelude
.body
.push(Instruction::load(var_type_id, id, gv.var_id, None));
gv.access_id = gv.var_id;
gv.handle_id = id;
} else if global_needs_wrapper(ir_module, var) {
let class = map_storage_class(var.space);
let pointer_type_id =
self.get_pointer_id(&ir_module.types, var.ty, class)?;
let index_id = self.get_index_constant(0);
let id = self.id_gen.next();
prelude.body.push(Instruction::access_chain(
pointer_type_id,
id,
gv.var_id,
&[index_id],
));
gv.access_id = id;
} else {
// by default, the variable ID is accessed as is
gv.access_id = gv.var_id;
};
}
}

// work around borrow checking in the presence of `self.xxx()` calls
self.global_variables[handle.index()] = gv;
Expand Down
4 changes: 3 additions & 1 deletion naga/tests/in/binding-buffer-arrays.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ struct UniformIndex {
index: u32
}

struct Foo { x: u32 }
struct Foo { x: u32, far: array<i32> }
@group(0) @binding(0)
var<storage, read> storage_array: binding_array<Foo, 1>;
@group(0) @binding(10)
Expand All @@ -23,5 +23,7 @@ fn main(fragment_in: FragmentIn) -> @location(0) u32 {
u1 += storage_array[uniform_index].x;
u1 += storage_array[non_uniform_index].x;

u1 += arrayLength(&storage_array[0].far);

return u1;
}
Loading

0 comments on commit bfe0b90

Please sign in to comment.