Skip to content

Commit

Permalink
Convert Naga's MSL Backend to Generating Argument Buffers for Binding Arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald committed Dec 16, 2024
1 parent 18a17ad commit d439c66
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 74 deletions.
1 change: 1 addition & 0 deletions naga/src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,4 +341,5 @@ pub const RESERVED: &[&str] = &[
"DefaultConstructible",
super::writer::FREXP_FUNCTION,
super::writer::MODF_FUNCTION,
super::writer::ARGUMENT_BUFFER_WRAPPER_STRUCT,
];
2 changes: 0 additions & 2 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ pub struct BindTarget {
pub buffer: Option<Slot>,
pub texture: Option<Slot>,
pub sampler: Option<BindSamplerTarget>,
/// If the binding is an unsized binding array, this overrides the size.
pub binding_array_size: Option<u32>,
pub mutable: bool,
}

Expand Down
53 changes: 34 additions & 19 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
/// Name of the generic wrapper struct that the MSL writer emits for binding arrays.
///
/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument.
/// However, if you put that texture inside a struct, everything is totally fine. This
/// baffles me to no end.
///
/// As such, every element of an argument buffer is wrapped in a struct that has a single
/// field of generic type `T`, and the binding array is declared as
/// `constant NagaArgumentBufferWrapper<metal::texture<..>>*` instead. The astute among
/// you will have noticed that this should look exactly the same to the compiler, and
/// you're correct.
///
/// This name is also listed in `keywords::RESERVED` (presumably so that identifiers coming
/// from the module cannot collide with it).
pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
Expand Down Expand Up @@ -275,24 +283,17 @@ impl Display for TypeContext<'_> {
crate::TypeInner::RayQuery => {
write!(out, "{RAY_QUERY_TYPE}")
}
crate::TypeInner::BindingArray { base, size } => {
crate::TypeInner::BindingArray { base, .. } => {
let base_tyname = Self {
handle: base,
first_time: false,
..*self
};

if let Some(&super::ResolvedBinding::Resource(super::BindTarget {
binding_array_size: Some(override_size),
..
})) = self.binding
{
write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>")
} else if let crate::ArraySize::Constant(size) = size {
write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>")
} else {
unreachable!("metal requires all arrays be constant sized");
}
write!(
out,
"constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*"
)
}
}
}
Expand Down Expand Up @@ -2549,6 +2550,8 @@ impl<W: Write> Writer<W> {
} => true,
_ => false,
};
let accessing_wrapped_binding_array =
matches!(*base_ty, crate::TypeInner::BindingArray { .. });

self.put_access_chain(base, policy, context)?;
if accessing_wrapped_array {
Expand Down Expand Up @@ -2585,6 +2588,10 @@ impl<W: Write> Writer<W> {

write!(self.out, "]")?;

if accessing_wrapped_binding_array {
write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
}

Ok(())
}

Expand Down Expand Up @@ -3696,7 +3703,18 @@ impl<W: Write> Writer<W> {
}

fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
let mut generated_argument_buffer_wrapper = false;
for (handle, ty) in module.types.iter() {
if let crate::TypeInner::BindingArray { .. } = ty.inner {
if !generated_argument_buffer_wrapper {
writeln!(self.out, "template <typename T>")?;
writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?;
writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?;
writeln!(self.out, "}};")?;
generated_argument_buffer_wrapper = true;
}
}

if !ty.needs_alias() {
continue;
}
Expand Down Expand Up @@ -4995,13 +5013,10 @@ template <typename A>
let target = options.get_resource_binding_target(ep, br);
let good = match target {
Some(target) => {
let binding_ty = match module.types[var.ty].inner {
crate::TypeInner::BindingArray { base, .. } => {
&module.types[base].inner
}
ref ty => ty,
};
match *binding_ty {
// We intentionally don't dereference binding_arrays here,
// so that binding arrays fall to the buffer location.

match module.types[var.ty].inner {
crate::TypeInner::Image { .. } => target.texture.is_some(),
crate::TypeInner::Sampler { .. } => {
target.sampler.is_some()
Expand Down
4 changes: 2 additions & 2 deletions naga/tests/in/binding-arrays.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
restrict_indexing: true
),
msl: (
lang_version: (2, 0),
lang_version: (3, 0),
per_entry_point_map: {
"main": (
resources: {
(group: 0, binding: 0): (texture: Some(0), binding_array_size: Some(10), mutable: false),
(group: 0, binding: 0): (buffer: Some(0), binding_array_size: Some(10), mutable: false),
},
sizes_buffer: None,
)
Expand Down
106 changes: 55 additions & 51 deletions naga/tests/out/msl/binding-arrays.msl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// language: metal2.0
// language: metal3.0
#include <metal_stdlib>
#include <simd/simd.h>

Expand All @@ -13,6 +13,10 @@ struct DefaultConstructible {
struct UniformIndex {
uint index;
};
template <typename T>
struct NagaArgumentBufferWrapper {
T inner;
};
struct FragmentIn {
uint index;
};
Expand All @@ -25,14 +29,14 @@ struct main_Output {
};
fragment main_Output main_(
main_Input varyings [[stage_in]]
, metal::array<metal::texture2d<float, metal::access::sample>, 10> texture_array_unbounded [[texture(0)]]
, metal::array<metal::texture2d<float, metal::access::sample>, 5> texture_array_bounded [[user(fake0)]]
, metal::array<metal::texture2d_array<float, metal::access::sample>, 5> texture_array_2darray [[user(fake0)]]
, metal::array<metal::texture2d_ms<float, metal::access::read>, 5> texture_array_multisampled [[user(fake0)]]
, metal::array<metal::depth2d<float, metal::access::sample>, 5> texture_array_depth [[user(fake0)]]
, metal::array<metal::texture2d<float, metal::access::write>, 5> texture_array_storage [[user(fake0)]]
, metal::array<metal::sampler, 5> samp [[user(fake0)]]
, metal::array<metal::sampler, 5> samp_comp [[user(fake0)]]
, constant NagaArgumentBufferWrapper<metal::texture2d<float, metal::access::sample>>* texture_array_unbounded [[buffer(0)]]
, constant NagaArgumentBufferWrapper<metal::texture2d<float, metal::access::sample>>* texture_array_bounded [[user(fake0)]]
, constant NagaArgumentBufferWrapper<metal::texture2d_array<float, metal::access::sample>>* texture_array_2darray [[user(fake0)]]
, constant NagaArgumentBufferWrapper<metal::texture2d_ms<float, metal::access::read>>* texture_array_multisampled [[user(fake0)]]
, constant NagaArgumentBufferWrapper<metal::depth2d<float, metal::access::sample>>* texture_array_depth [[user(fake0)]]
, constant NagaArgumentBufferWrapper<metal::texture2d<float, metal::access::write>>* texture_array_storage [[user(fake0)]]
, constant NagaArgumentBufferWrapper<metal::sampler>* samp [[user(fake0)]]
, constant NagaArgumentBufferWrapper<metal::sampler>* samp_comp [[user(fake0)]]
, constant UniformIndex& uni [[user(fake0)]]
) {
const FragmentIn fragment_in = { varyings.index };
Expand All @@ -45,116 +49,116 @@ fragment main_Output main_(
metal::float2 uv = metal::float2(0.0);
metal::int2 pix = metal::int2(0);
metal::uint2 _e22 = u2_;
u2_ = _e22 + metal::uint2(texture_array_unbounded[0].get_width(), texture_array_unbounded[0].get_height());
u2_ = _e22 + metal::uint2(texture_array_unbounded[0].inner.get_width(), texture_array_unbounded[0].inner.get_height());
metal::uint2 _e27 = u2_;
u2_ = _e27 + metal::uint2(texture_array_unbounded[uniform_index].get_width(), texture_array_unbounded[uniform_index].get_height());
u2_ = _e27 + metal::uint2(texture_array_unbounded[uniform_index].inner.get_width(), texture_array_unbounded[uniform_index].inner.get_height());
metal::uint2 _e32 = u2_;
u2_ = _e32 + metal::uint2(texture_array_unbounded[non_uniform_index].get_width(), texture_array_unbounded[non_uniform_index].get_height());
metal::float4 _e38 = texture_array_bounded[0].gather(samp[0], uv);
u2_ = _e32 + metal::uint2(texture_array_unbounded[non_uniform_index].inner.get_width(), texture_array_unbounded[non_uniform_index].inner.get_height());
metal::float4 _e38 = texture_array_bounded[0].inner.gather(samp[0].inner, uv);
metal::float4 _e39 = v4_;
v4_ = _e39 + _e38;
metal::float4 _e45 = texture_array_bounded[uniform_index].gather(samp[uniform_index], uv);
metal::float4 _e45 = texture_array_bounded[uniform_index].inner.gather(samp[uniform_index].inner, uv);
metal::float4 _e46 = v4_;
v4_ = _e46 + _e45;
metal::float4 _e52 = texture_array_bounded[non_uniform_index].gather(samp[non_uniform_index], uv);
metal::float4 _e52 = texture_array_bounded[non_uniform_index].inner.gather(samp[non_uniform_index].inner, uv);
metal::float4 _e53 = v4_;
v4_ = _e53 + _e52;
metal::float4 _e60 = texture_array_depth[0].gather_compare(samp_comp[0], uv, 0.0);
metal::float4 _e60 = texture_array_depth[0].inner.gather_compare(samp_comp[0].inner, uv, 0.0);
metal::float4 _e61 = v4_;
v4_ = _e61 + _e60;
metal::float4 _e68 = texture_array_depth[uniform_index].gather_compare(samp_comp[uniform_index], uv, 0.0);
metal::float4 _e68 = texture_array_depth[uniform_index].inner.gather_compare(samp_comp[uniform_index].inner, uv, 0.0);
metal::float4 _e69 = v4_;
v4_ = _e69 + _e68;
metal::float4 _e76 = texture_array_depth[non_uniform_index].gather_compare(samp_comp[non_uniform_index], uv, 0.0);
metal::float4 _e76 = texture_array_depth[non_uniform_index].inner.gather_compare(samp_comp[non_uniform_index].inner, uv, 0.0);
metal::float4 _e77 = v4_;
v4_ = _e77 + _e76;
metal::float4 _e82 = (uint(0) < texture_array_unbounded[0].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[0].get_width(0), texture_array_unbounded[0].get_height(0))) ? texture_array_unbounded[0].read(metal::uint2(pix), 0): DefaultConstructible());
metal::float4 _e82 = (uint(0) < texture_array_unbounded[0].inner.get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[0].inner.get_width(0), texture_array_unbounded[0].inner.get_height(0))) ? texture_array_unbounded[0].inner.read(metal::uint2(pix), 0): DefaultConstructible());
metal::float4 _e83 = v4_;
v4_ = _e83 + _e82;
metal::float4 _e88 = (uint(0) < texture_array_unbounded[uniform_index].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[uniform_index].get_width(0), texture_array_unbounded[uniform_index].get_height(0))) ? texture_array_unbounded[uniform_index].read(metal::uint2(pix), 0): DefaultConstructible());
metal::float4 _e88 = (uint(0) < texture_array_unbounded[uniform_index].inner.get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[uniform_index].inner.get_width(0), texture_array_unbounded[uniform_index].inner.get_height(0))) ? texture_array_unbounded[uniform_index].inner.read(metal::uint2(pix), 0): DefaultConstructible());
metal::float4 _e89 = v4_;
v4_ = _e89 + _e88;
metal::float4 _e94 = (uint(0) < texture_array_unbounded[non_uniform_index].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[non_uniform_index].get_width(0), texture_array_unbounded[non_uniform_index].get_height(0))) ? texture_array_unbounded[non_uniform_index].read(metal::uint2(pix), 0): DefaultConstructible());
metal::float4 _e94 = (uint(0) < texture_array_unbounded[non_uniform_index].inner.get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[non_uniform_index].inner.get_width(0), texture_array_unbounded[non_uniform_index].inner.get_height(0))) ? texture_array_unbounded[non_uniform_index].inner.read(metal::uint2(pix), 0): DefaultConstructible());
metal::float4 _e95 = v4_;
v4_ = _e95 + _e94;
uint _e100 = u1_;
u1_ = _e100 + texture_array_2darray[0].get_array_size();
u1_ = _e100 + texture_array_2darray[0].inner.get_array_size();
uint _e105 = u1_;
u1_ = _e105 + texture_array_2darray[uniform_index].get_array_size();
u1_ = _e105 + texture_array_2darray[uniform_index].inner.get_array_size();
uint _e110 = u1_;
u1_ = _e110 + texture_array_2darray[non_uniform_index].get_array_size();
u1_ = _e110 + texture_array_2darray[non_uniform_index].inner.get_array_size();
uint _e115 = u1_;
u1_ = _e115 + texture_array_bounded[0].get_num_mip_levels();
u1_ = _e115 + texture_array_bounded[0].inner.get_num_mip_levels();
uint _e120 = u1_;
u1_ = _e120 + texture_array_bounded[uniform_index].get_num_mip_levels();
u1_ = _e120 + texture_array_bounded[uniform_index].inner.get_num_mip_levels();
uint _e125 = u1_;
u1_ = _e125 + texture_array_bounded[non_uniform_index].get_num_mip_levels();
u1_ = _e125 + texture_array_bounded[non_uniform_index].inner.get_num_mip_levels();
uint _e130 = u1_;
u1_ = _e130 + texture_array_multisampled[0].get_num_samples();
u1_ = _e130 + texture_array_multisampled[0].inner.get_num_samples();
uint _e135 = u1_;
u1_ = _e135 + texture_array_multisampled[uniform_index].get_num_samples();
u1_ = _e135 + texture_array_multisampled[uniform_index].inner.get_num_samples();
uint _e140 = u1_;
u1_ = _e140 + texture_array_multisampled[non_uniform_index].get_num_samples();
metal::float4 _e146 = texture_array_bounded[0].sample(samp[0], uv);
u1_ = _e140 + texture_array_multisampled[non_uniform_index].inner.get_num_samples();
metal::float4 _e146 = texture_array_bounded[0].inner.sample(samp[0].inner, uv);
metal::float4 _e147 = v4_;
v4_ = _e147 + _e146;
metal::float4 _e153 = texture_array_bounded[uniform_index].sample(samp[uniform_index], uv);
metal::float4 _e153 = texture_array_bounded[uniform_index].inner.sample(samp[uniform_index].inner, uv);
metal::float4 _e154 = v4_;
v4_ = _e154 + _e153;
metal::float4 _e160 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv);
metal::float4 _e160 = texture_array_bounded[non_uniform_index].inner.sample(samp[non_uniform_index].inner, uv);
metal::float4 _e161 = v4_;
v4_ = _e161 + _e160;
metal::float4 _e168 = texture_array_bounded[0].sample(samp[0], uv, metal::bias(0.0));
metal::float4 _e168 = texture_array_bounded[0].inner.sample(samp[0].inner, uv, metal::bias(0.0));
metal::float4 _e169 = v4_;
v4_ = _e169 + _e168;
metal::float4 _e176 = texture_array_bounded[uniform_index].sample(samp[uniform_index], uv, metal::bias(0.0));
metal::float4 _e176 = texture_array_bounded[uniform_index].inner.sample(samp[uniform_index].inner, uv, metal::bias(0.0));
metal::float4 _e177 = v4_;
v4_ = _e177 + _e176;
metal::float4 _e184 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv, metal::bias(0.0));
metal::float4 _e184 = texture_array_bounded[non_uniform_index].inner.sample(samp[non_uniform_index].inner, uv, metal::bias(0.0));
metal::float4 _e185 = v4_;
v4_ = _e185 + _e184;
float _e192 = texture_array_depth[0].sample_compare(samp_comp[0], uv, 0.0);
float _e192 = texture_array_depth[0].inner.sample_compare(samp_comp[0].inner, uv, 0.0);
float _e193 = v1_;
v1_ = _e193 + _e192;
float _e200 = texture_array_depth[uniform_index].sample_compare(samp_comp[uniform_index], uv, 0.0);
float _e200 = texture_array_depth[uniform_index].inner.sample_compare(samp_comp[uniform_index].inner, uv, 0.0);
float _e201 = v1_;
v1_ = _e201 + _e200;
float _e208 = texture_array_depth[non_uniform_index].sample_compare(samp_comp[non_uniform_index], uv, 0.0);
float _e208 = texture_array_depth[non_uniform_index].inner.sample_compare(samp_comp[non_uniform_index].inner, uv, 0.0);
float _e209 = v1_;
v1_ = _e209 + _e208;
float _e216 = texture_array_depth[0].sample_compare(samp_comp[0], uv, 0.0);
float _e216 = texture_array_depth[0].inner.sample_compare(samp_comp[0].inner, uv, 0.0);
float _e217 = v1_;
v1_ = _e217 + _e216;
float _e224 = texture_array_depth[uniform_index].sample_compare(samp_comp[uniform_index], uv, 0.0);
float _e224 = texture_array_depth[uniform_index].inner.sample_compare(samp_comp[uniform_index].inner, uv, 0.0);
float _e225 = v1_;
v1_ = _e225 + _e224;
float _e232 = texture_array_depth[non_uniform_index].sample_compare(samp_comp[non_uniform_index], uv, 0.0);
float _e232 = texture_array_depth[non_uniform_index].inner.sample_compare(samp_comp[non_uniform_index].inner, uv, 0.0);
float _e233 = v1_;
v1_ = _e233 + _e232;
metal::float4 _e239 = texture_array_bounded[0].sample(samp[0], uv, metal::gradient2d(uv, uv));
metal::float4 _e239 = texture_array_bounded[0].inner.sample(samp[0].inner, uv, metal::gradient2d(uv, uv));
metal::float4 _e240 = v4_;
v4_ = _e240 + _e239;
metal::float4 _e246 = texture_array_bounded[uniform_index].sample(samp[uniform_index], uv, metal::gradient2d(uv, uv));
metal::float4 _e246 = texture_array_bounded[uniform_index].inner.sample(samp[uniform_index].inner, uv, metal::gradient2d(uv, uv));
metal::float4 _e247 = v4_;
v4_ = _e247 + _e246;
metal::float4 _e253 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv, metal::gradient2d(uv, uv));
metal::float4 _e253 = texture_array_bounded[non_uniform_index].inner.sample(samp[non_uniform_index].inner, uv, metal::gradient2d(uv, uv));
metal::float4 _e254 = v4_;
v4_ = _e254 + _e253;
metal::float4 _e261 = texture_array_bounded[0].sample(samp[0], uv, metal::level(0.0));
metal::float4 _e261 = texture_array_bounded[0].inner.sample(samp[0].inner, uv, metal::level(0.0));
metal::float4 _e262 = v4_;
v4_ = _e262 + _e261;
metal::float4 _e269 = texture_array_bounded[uniform_index].sample(samp[uniform_index], uv, metal::level(0.0));
metal::float4 _e269 = texture_array_bounded[uniform_index].inner.sample(samp[uniform_index].inner, uv, metal::level(0.0));
metal::float4 _e270 = v4_;
v4_ = _e270 + _e269;
metal::float4 _e277 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv, metal::level(0.0));
metal::float4 _e277 = texture_array_bounded[non_uniform_index].inner.sample(samp[non_uniform_index].inner, uv, metal::level(0.0));
metal::float4 _e278 = v4_;
v4_ = _e278 + _e277;
metal::float4 _e282 = v4_;
texture_array_storage[0].write(_e282, metal::uint2(pix));
texture_array_storage[0].inner.write(_e282, metal::uint2(pix));
metal::float4 _e285 = v4_;
texture_array_storage[uniform_index].write(_e285, metal::uint2(pix));
texture_array_storage[uniform_index].inner.write(_e285, metal::uint2(pix));
metal::float4 _e288 = v4_;
texture_array_storage[non_uniform_index].write(_e288, metal::uint2(pix));
texture_array_storage[non_uniform_index].inner.write(_e288, metal::uint2(pix));
metal::uint2 _e289 = u2_;
uint _e290 = u1_;
metal::float2 v2_ = static_cast<metal::float2>(_e289 + metal::uint2(_e290));
Expand Down

0 comments on commit d439c66

Please sign in to comment.