From d439c6643eeebef32e0f100420efd44d45c355fe Mon Sep 17 00:00:00 2001 From: Connor Fitzgerald Date: Mon, 16 Dec 2024 00:49:15 -0500 Subject: [PATCH] Convert Naga's MSL Backend to Generating Argument Buffers for Binding Arrays --- naga/src/back/msl/keywords.rs | 1 + naga/src/back/msl/mod.rs | 2 - naga/src/back/msl/writer.rs | 53 ++++++++----- naga/tests/in/binding-arrays.param.ron | 4 +- naga/tests/out/msl/binding-arrays.msl | 106 +++++++++++++------------ 5 files changed, 92 insertions(+), 74 deletions(-) diff --git a/naga/src/back/msl/keywords.rs b/naga/src/back/msl/keywords.rs index 73c457dd349..a4eabab234c 100644 --- a/naga/src/back/msl/keywords.rs +++ b/naga/src/back/msl/keywords.rs @@ -341,4 +341,5 @@ pub const RESERVED: &[&str] = &[ "DefaultConstructible", super::writer::FREXP_FUNCTION, super::writer::MODF_FUNCTION, + super::writer::ARGUMENT_BUFFER_WRAPPER_STRUCT, ]; diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 453b7136b87..b2c61fff718 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -59,8 +59,6 @@ pub struct BindTarget { pub buffer: Option, pub texture: Option, pub sampler: Option, - /// If the binding is an unsized binding array, this overrides the size. - pub binding_array_size: Option, pub mutable: bool, } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index c1198238004..4594f47f752 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -36,6 +36,14 @@ const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type"; pub(crate) const ATOMIC_COMP_EXCH_FUNCTION: &str = "naga_atomic_compare_exchange_weak_explicit"; pub(crate) const MODF_FUNCTION: &str = "naga_modf"; pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; +/// For some reason, Metal does not let you have `metal::texture<..>*` as a buffer argument. +/// However, if you put that texture inside a struct, everything is totally fine. This +/// baffles me to no end. +/// +/// As such, we wrap all argument buffers in a struct that has a single generic field. +/// This allows `NagaArgumentBufferWrapper>*` to work. The astute among +/// you have noticed that this should be exactly the same to the compiler, and you're correct. +pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapper"; /// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. /// @@ -275,24 +283,17 @@ impl Display for TypeContext<'_> { crate::TypeInner::RayQuery => { write!(out, "{RAY_QUERY_TYPE}") } - crate::TypeInner::BindingArray { base, size } => { + crate::TypeInner::BindingArray { base, .. } => { let base_tyname = Self { handle: base, first_time: false, ..*self }; - if let Some(&super::ResolvedBinding::Resource(super::BindTarget { - binding_array_size: Some(override_size), - .. - })) = self.binding - { - write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>") - } else if let crate::ArraySize::Constant(size) = size { - write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>") - } else { - unreachable!("metal requires all arrays be constant sized"); - } + write!( + out, + "constant {ARGUMENT_BUFFER_WRAPPER_STRUCT}<{base_tyname}>*" + ) } } } @@ -2549,6 +2550,8 @@ impl Writer { } => true, _ => false, }; + let accessing_wrapped_binding_array = + matches!(*base_ty, crate::TypeInner::BindingArray { .. }); self.put_access_chain(base, policy, context)?; if accessing_wrapped_array { @@ -2585,6 +2588,10 @@ impl Writer { write!(self.out, "]")?; + if accessing_wrapped_binding_array { + write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?; + } + Ok(()) } @@ -3696,7 +3703,18 @@ impl Writer { } fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult { + let mut generated_argument_buffer_wrapper = false; for (handle, ty) in module.types.iter() { + if let crate::TypeInner::BindingArray { .. } = ty.inner { + if !generated_argument_buffer_wrapper { + writeln!(self.out, "template ")?; + writeln!(self.out, "struct {ARGUMENT_BUFFER_WRAPPER_STRUCT} {{")?; + writeln!(self.out, "{}T {WRAPPED_ARRAY_FIELD};", back::INDENT)?; + writeln!(self.out, "}};")?; + generated_argument_buffer_wrapper = true; + } + } + if !ty.needs_alias() { continue; } @@ -4995,13 +5013,10 @@ template let target = options.get_resource_binding_target(ep, br); let good = match target { Some(target) => { - let binding_ty = match module.types[var.ty].inner { - crate::TypeInner::BindingArray { base, .. } => { - &module.types[base].inner - } - ref ty => ty, - }; - match *binding_ty { + // We intentionally don't dereference binding_arrays here, + // so that binding arrays fall to the buffer location. + + match module.types[var.ty].inner { crate::TypeInner::Image { .. } => target.texture.is_some(), crate::TypeInner::Sampler { .. } => { target.sampler.is_some() diff --git a/naga/tests/in/binding-arrays.param.ron b/naga/tests/in/binding-arrays.param.ron index 249a4afe2ae..96807d825a5 100644 --- a/naga/tests/in/binding-arrays.param.ron +++ b/naga/tests/in/binding-arrays.param.ron @@ -19,11 +19,11 @@ restrict_indexing: true ), msl: ( - lang_version: (2, 0), + lang_version: (3, 0), per_entry_point_map: { "main": ( resources: { - (group: 0, binding: 0): (texture: Some(0), binding_array_size: Some(10), mutable: false), + (group: 0, binding: 0): (buffer: Some(0), binding_array_size: Some(10), mutable: false), }, sizes_buffer: None, ) diff --git a/naga/tests/out/msl/binding-arrays.msl b/naga/tests/out/msl/binding-arrays.msl index 75f787a9f20..f62546241aa 100644 --- a/naga/tests/out/msl/binding-arrays.msl +++ b/naga/tests/out/msl/binding-arrays.msl @@ -1,4 +1,4 @@ -// language: metal2.0 +// language: metal3.0 #include #include @@ -13,6 +13,10 @@ struct DefaultConstructible { struct UniformIndex { uint index; }; +template +struct NagaArgumentBufferWrapper { + T inner; +}; struct FragmentIn { uint index; }; @@ -25,14 +29,14 @@ struct main_Output { }; fragment main_Output main_( main_Input varyings [[stage_in]] -, metal::array, 10> texture_array_unbounded [[texture(0)]] -, metal::array, 5> texture_array_bounded [[user(fake0)]] -, metal::array, 5> texture_array_2darray [[user(fake0)]] -, metal::array, 5> texture_array_multisampled [[user(fake0)]] -, metal::array, 5> texture_array_depth [[user(fake0)]] -, metal::array, 5> texture_array_storage [[user(fake0)]] -, metal::array samp [[user(fake0)]] -, metal::array samp_comp [[user(fake0)]] +, constant NagaArgumentBufferWrapper>* texture_array_unbounded [[buffer(0)]] +, constant NagaArgumentBufferWrapper>* texture_array_bounded [[user(fake0)]] +, constant NagaArgumentBufferWrapper>* texture_array_2darray [[user(fake0)]] +, constant NagaArgumentBufferWrapper>* texture_array_multisampled [[user(fake0)]] +, constant NagaArgumentBufferWrapper>* texture_array_depth [[user(fake0)]] +, constant NagaArgumentBufferWrapper>* texture_array_storage [[user(fake0)]] +, constant NagaArgumentBufferWrapper* samp [[user(fake0)]] +, constant NagaArgumentBufferWrapper* samp_comp [[user(fake0)]] , constant UniformIndex& uni [[user(fake0)]] ) { const FragmentIn fragment_in = { varyings.index }; @@ -45,116 +49,116 @@ fragment main_Output main_( metal::float2 uv = metal::float2(0.0); metal::int2 pix = metal::int2(0); metal::uint2 _e22 = u2_; - u2_ = _e22 + metal::uint2(texture_array_unbounded[0].get_width(), texture_array_unbounded[0].get_height()); + u2_ = _e22 + metal::uint2(texture_array_unbounded[0].inner.get_width(), texture_array_unbounded[0].inner.get_height()); metal::uint2 _e27 = u2_; - u2_ = _e27 + metal::uint2(texture_array_unbounded[uniform_index].get_width(), texture_array_unbounded[uniform_index].get_height()); + u2_ = _e27 + metal::uint2(texture_array_unbounded[uniform_index].inner.get_width(), texture_array_unbounded[uniform_index].inner.get_height()); metal::uint2 _e32 = u2_; - u2_ = _e32 + metal::uint2(texture_array_unbounded[non_uniform_index].get_width(), texture_array_unbounded[non_uniform_index].get_height()); - metal::float4 _e38 = texture_array_bounded[0].gather(samp[0], uv); + u2_ = _e32 + metal::uint2(texture_array_unbounded[non_uniform_index].inner.get_width(), texture_array_unbounded[non_uniform_index].inner.get_height()); + metal::float4 _e38 = texture_array_bounded[0].inner.gather(samp[0].inner, uv); metal::float4 _e39 = v4_; v4_ = _e39 + _e38; - metal::float4 _e45 = texture_array_bounded[uniform_index].gather(samp[uniform_index], uv); + metal::float4 _e45 = texture_array_bounded[uniform_index].inner.gather(samp[uniform_index].inner, uv); metal::float4 _e46 = v4_; v4_ = _e46 + _e45; - metal::float4 _e52 = texture_array_bounded[non_uniform_index].gather(samp[non_uniform_index], uv); + metal::float4 _e52 = texture_array_bounded[non_uniform_index].inner.gather(samp[non_uniform_index].inner, uv); metal::float4 _e53 = v4_; v4_ = _e53 + _e52; - metal::float4 _e60 = texture_array_depth[0].gather_compare(samp_comp[0], uv, 0.0); + metal::float4 _e60 = texture_array_depth[0].inner.gather_compare(samp_comp[0].inner, uv, 0.0); metal::float4 _e61 = v4_; v4_ = _e61 + _e60; - metal::float4 _e68 = texture_array_depth[uniform_index].gather_compare(samp_comp[uniform_index], uv, 0.0); + metal::float4 _e68 = texture_array_depth[uniform_index].inner.gather_compare(samp_comp[uniform_index].inner, uv, 0.0); metal::float4 _e69 = v4_; v4_ = _e69 + _e68; - metal::float4 _e76 = texture_array_depth[non_uniform_index].gather_compare(samp_comp[non_uniform_index], uv, 0.0); + metal::float4 _e76 = texture_array_depth[non_uniform_index].inner.gather_compare(samp_comp[non_uniform_index].inner, uv, 0.0); metal::float4 _e77 = v4_; v4_ = _e77 + _e76; - metal::float4 _e82 = (uint(0) < texture_array_unbounded[0].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[0].get_width(0), texture_array_unbounded[0].get_height(0))) ? texture_array_unbounded[0].read(metal::uint2(pix), 0): DefaultConstructible()); + metal::float4 _e82 = (uint(0) < texture_array_unbounded[0].inner.get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[0].inner.get_width(0), texture_array_unbounded[0].inner.get_height(0))) ? texture_array_unbounded[0].inner.read(metal::uint2(pix), 0): DefaultConstructible()); metal::float4 _e83 = v4_; v4_ = _e83 + _e82; - metal::float4 _e88 = (uint(0) < texture_array_unbounded[uniform_index].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[uniform_index].get_width(0), texture_array_unbounded[uniform_index].get_height(0))) ? texture_array_unbounded[uniform_index].read(metal::uint2(pix), 0): DefaultConstructible()); + metal::float4 _e88 = (uint(0) < texture_array_unbounded[uniform_index].inner.get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[uniform_index].inner.get_width(0), texture_array_unbounded[uniform_index].inner.get_height(0))) ? texture_array_unbounded[uniform_index].inner.read(metal::uint2(pix), 0): DefaultConstructible()); metal::float4 _e89 = v4_; v4_ = _e89 + _e88; - metal::float4 _e94 = (uint(0) < texture_array_unbounded[non_uniform_index].get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[non_uniform_index].get_width(0), texture_array_unbounded[non_uniform_index].get_height(0))) ? texture_array_unbounded[non_uniform_index].read(metal::uint2(pix), 0): DefaultConstructible()); + metal::float4 _e94 = (uint(0) < texture_array_unbounded[non_uniform_index].inner.get_num_mip_levels() && metal::all(metal::uint2(pix) < metal::uint2(texture_array_unbounded[non_uniform_index].inner.get_width(0), texture_array_unbounded[non_uniform_index].inner.get_height(0))) ? texture_array_unbounded[non_uniform_index].inner.read(metal::uint2(pix), 0): DefaultConstructible()); metal::float4 _e95 = v4_; v4_ = _e95 + _e94; uint _e100 = u1_; - u1_ = _e100 + texture_array_2darray[0].get_array_size(); + u1_ = _e100 + texture_array_2darray[0].inner.get_array_size(); uint _e105 = u1_; - u1_ = _e105 + texture_array_2darray[uniform_index].get_array_size(); + u1_ = _e105 + texture_array_2darray[uniform_index].inner.get_array_size(); uint _e110 = u1_; - u1_ = _e110 + texture_array_2darray[non_uniform_index].get_array_size(); + u1_ = _e110 + texture_array_2darray[non_uniform_index].inner.get_array_size(); uint _e115 = u1_; - u1_ = _e115 + texture_array_bounded[0].get_num_mip_levels(); + u1_ = _e115 + texture_array_bounded[0].inner.get_num_mip_levels(); uint _e120 = u1_; - u1_ = _e120 + texture_array_bounded[uniform_index].get_num_mip_levels(); + u1_ = _e120 + texture_array_bounded[uniform_index].inner.get_num_mip_levels(); uint _e125 = u1_; - u1_ = _e125 + texture_array_bounded[non_uniform_index].get_num_mip_levels(); + u1_ = _e125 + texture_array_bounded[non_uniform_index].inner.get_num_mip_levels(); uint _e130 = u1_; - u1_ = _e130 + texture_array_multisampled[0].get_num_samples(); + u1_ = _e130 + texture_array_multisampled[0].inner.get_num_samples(); uint _e135 = u1_; - u1_ = _e135 + texture_array_multisampled[uniform_index].get_num_samples(); + u1_ = _e135 + texture_array_multisampled[uniform_index].inner.get_num_samples(); uint _e140 = u1_; - u1_ = _e140 + texture_array_multisampled[non_uniform_index].get_num_samples(); - metal::float4 _e146 = texture_array_bounded[0].sample(samp[0], uv); + u1_ = _e140 + texture_array_multisampled[non_uniform_index].inner.get_num_samples(); + metal::float4 _e146 = texture_array_bounded[0].inner.sample(samp[0].inner, uv); metal::float4 _e147 = v4_; v4_ = _e147 + _e146; - metal::float4 _e153 = texture_array_bounded[uniform_index].sample(samp[uniform_index], uv); + metal::float4 _e153 = texture_array_bounded[uniform_index].inner.sample(samp[uniform_index].inner, uv); metal::float4 _e154 = v4_; v4_ = _e154 + _e153; - metal::float4 _e160 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv); + metal::float4 _e160 = texture_array_bounded[non_uniform_index].inner.sample(samp[non_uniform_index].inner, uv); metal::float4 _e161 = v4_; v4_ = _e161 + _e160; - metal::float4 _e168 = texture_array_bounded[0].sample(samp[0], uv, metal::bias(0.0)); + metal::float4 _e168 = texture_array_bounded[0].inner.sample(samp[0].inner, uv, metal::bias(0.0)); metal::float4 _e169 = v4_; v4_ = _e169 + _e168; - metal::float4 _e176 = texture_array_bounded[uniform_index].sample(samp[uniform_index], uv, metal::bias(0.0)); + metal::float4 _e176 = texture_array_bounded[uniform_index].inner.sample(samp[uniform_index].inner, uv, metal::bias(0.0)); metal::float4 _e177 = v4_; v4_ = _e177 + _e176; - metal::float4 _e184 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv, metal::bias(0.0)); + metal::float4 _e184 = texture_array_bounded[non_uniform_index].inner.sample(samp[non_uniform_index].inner, uv, metal::bias(0.0)); metal::float4 _e185 = v4_; v4_ = _e185 + _e184; - float _e192 = texture_array_depth[0].sample_compare(samp_comp[0], uv, 0.0); + float _e192 = texture_array_depth[0].inner.sample_compare(samp_comp[0].inner, uv, 0.0); float _e193 = v1_; v1_ = _e193 + _e192; - float _e200 = texture_array_depth[uniform_index].sample_compare(samp_comp[uniform_index], uv, 0.0); + float _e200 = texture_array_depth[uniform_index].inner.sample_compare(samp_comp[uniform_index].inner, uv, 0.0); float _e201 = v1_; v1_ = _e201 + _e200; - float _e208 = texture_array_depth[non_uniform_index].sample_compare(samp_comp[non_uniform_index], uv, 0.0); + float _e208 = texture_array_depth[non_uniform_index].inner.sample_compare(samp_comp[non_uniform_index].inner, uv, 0.0); float _e209 = v1_; v1_ = _e209 + _e208; - float _e216 = texture_array_depth[0].sample_compare(samp_comp[0], uv, 0.0); + float _e216 = texture_array_depth[0].inner.sample_compare(samp_comp[0].inner, uv, 0.0); float _e217 = v1_; v1_ = _e217 + _e216; - float _e224 = texture_array_depth[uniform_index].sample_compare(samp_comp[uniform_index], uv, 0.0); + float _e224 = texture_array_depth[uniform_index].inner.sample_compare(samp_comp[uniform_index].inner, uv, 0.0); float _e225 = v1_; v1_ = _e225 + _e224; - float _e232 = texture_array_depth[non_uniform_index].sample_compare(samp_comp[non_uniform_index], uv, 0.0); + float _e232 = texture_array_depth[non_uniform_index].inner.sample_compare(samp_comp[non_uniform_index].inner, uv, 0.0); float _e233 = v1_; v1_ = _e233 + _e232; - metal::float4 _e239 = texture_array_bounded[0].sample(samp[0], uv, metal::gradient2d(uv, uv)); + metal::float4 _e239 = texture_array_bounded[0].inner.sample(samp[0].inner, uv, metal::gradient2d(uv, uv)); metal::float4 _e240 = v4_; v4_ = _e240 + _e239; - metal::float4 _e246 = texture_array_bounded[uniform_index].sample(samp[uniform_index], uv, metal::gradient2d(uv, uv)); + metal::float4 _e246 = texture_array_bounded[uniform_index].inner.sample(samp[uniform_index].inner, uv, metal::gradient2d(uv, uv)); metal::float4 _e247 = v4_; v4_ = _e247 + _e246; - metal::float4 _e253 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv, metal::gradient2d(uv, uv)); + metal::float4 _e253 = texture_array_bounded[non_uniform_index].inner.sample(samp[non_uniform_index].inner, uv, metal::gradient2d(uv, uv)); metal::float4 _e254 = v4_; v4_ = _e254 + _e253; - metal::float4 _e261 = texture_array_bounded[0].sample(samp[0], uv, metal::level(0.0)); + metal::float4 _e261 = texture_array_bounded[0].inner.sample(samp[0].inner, uv, metal::level(0.0)); metal::float4 _e262 = v4_; v4_ = _e262 + _e261; - metal::float4 _e269 = texture_array_bounded[uniform_index].sample(samp[uniform_index], uv, metal::level(0.0)); + metal::float4 _e269 = texture_array_bounded[uniform_index].inner.sample(samp[uniform_index].inner, uv, metal::level(0.0)); metal::float4 _e270 = v4_; v4_ = _e270 + _e269; - metal::float4 _e277 = texture_array_bounded[non_uniform_index].sample(samp[non_uniform_index], uv, metal::level(0.0)); + metal::float4 _e277 = texture_array_bounded[non_uniform_index].inner.sample(samp[non_uniform_index].inner, uv, metal::level(0.0)); metal::float4 _e278 = v4_; v4_ = _e278 + _e277; metal::float4 _e282 = v4_; - texture_array_storage[0].write(_e282, metal::uint2(pix)); + texture_array_storage[0].inner.write(_e282, metal::uint2(pix)); metal::float4 _e285 = v4_; - texture_array_storage[uniform_index].write(_e285, metal::uint2(pix)); + texture_array_storage[uniform_index].inner.write(_e285, metal::uint2(pix)); metal::float4 _e288 = v4_; - texture_array_storage[non_uniform_index].write(_e288, metal::uint2(pix)); + texture_array_storage[non_uniform_index].inner.write(_e288, metal::uint2(pix)); metal::uint2 _e289 = u2_; uint _e290 = u1_; metal::float2 v2_ = static_cast(_e289 + metal::uint2(_e290));