diff --git a/src/compose/mod.rs b/src/compose/mod.rs index caa27e6..b679aff 100644 --- a/src/compose/mod.rs +++ b/src/compose/mod.rs @@ -1164,7 +1164,9 @@ impl Composer { let mut owned_types = HashSet::new(); for (h, ty) in source_ir.types.iter() { if let Some(name) = &ty.name { - if !name.contains(DECORATION_PRE) { + // we need to exclude autogenerated struct names, i.e. those that begin with "__" + // "__" is a reserved prefix for naga so user variables cannot use it. + if !name.contains(DECORATION_PRE) && !name.starts_with("__") { let name = format!("{module_decoration}{name}"); owned_types.insert(name.clone()); // copy and rename types diff --git a/src/compose/test.rs b/src/compose/test.rs index ff006e5..bed0462 100644 --- a/src/compose/test.rs +++ b/src/compose/test.rs @@ -1122,6 +1122,51 @@ mod test { output_eq!(wgsl, "tests/expected/use_shared_global.txt"); } + #[test] + fn test_atomics() { + let mut composer = Composer::default(); + + composer + .add_composable_module(ComposableModuleDescriptor { + source: include_str!("tests/atomics/mod.wgsl"), + file_path: "tests/atomics/mod.wgsl", + ..Default::default() + }) + .unwrap(); + + // TODO enable this test when HLSL support is available + if cfg!(feature = "test_shader") && false { + assert_eq!(test_shader(&mut composer), 28.0); + } + + let module = composer + .make_naga_module(NagaModuleDescriptor { + source: include_str!("tests/atomics/top.wgsl"), + file_path: "tests/atomics/top.wgsl", + ..Default::default() + }) + .unwrap(); + + let info = naga::valid::Validator::new( + naga::valid::ValidationFlags::all(), + naga::valid::Capabilities::default(), + ) + .validate(&module) + .unwrap(); + let wgsl = naga::back::wgsl::write_string( + &module, + &info, + naga::back::wgsl::WriterFlags::EXPLICIT_TYPES, + ) + .unwrap(); + + // let mut f = std::fs::File::create("atomics.txt").unwrap(); + // f.write_all(wgsl.as_bytes()).unwrap(); + // drop(f); + + output_eq!(wgsl, "tests/expected/atomics.txt"); + } + #[cfg(feature = "test_shader")] #[test] fn effective_defs() { diff --git a/src/compose/tests/atomics/mod.wgsl b/src/compose/tests/atomics/mod.wgsl new file mode 100644 index 0000000..53f96c3 --- /dev/null +++ b/src/compose/tests/atomics/mod.wgsl @@ -0,0 +1,19 @@ +#define_import_path test_module + +var atom: atomic; + +fn entry_point() -> f32 { + atomicStore(&atom, 1u); // atom = 1 + var y = atomicLoad(&atom); // y = 1, atom = 1 + y += atomicAdd(&atom, 2u); // y = 2, atom = 3 + y += atomicSub(&atom, 1u); // y = 5, atom = 2 + y += atomicMax(&atom, 5u); // y = 7, atom = 5 + y += atomicMin(&atom, 4u); // y = 12, atom = 4 + y += atomicExchange(&atom, y); // y = 16, atom = 12 + let exchange = atomicCompareExchangeWeak(&atom, 12u, 0u); + if exchange.exchanged { + y += exchange.old_value; // y = 28, atom = 0 + } + + return f32(y); // 28.0 +} \ No newline at end of file diff --git a/src/compose/tests/atomics/top.wgsl b/src/compose/tests/atomics/top.wgsl new file mode 100644 index 0000000..f2969ae --- /dev/null +++ b/src/compose/tests/atomics/top.wgsl @@ -0,0 +1,5 @@ +#import test_module + +fn main() -> f32 { + return test_module::entry_point(); +} \ No newline at end of file diff --git a/src/compose/tests/expected/atomics.txt b/src/compose/tests/expected/atomics.txt new file mode 100644 index 0000000..051f2e4 --- /dev/null +++ b/src/compose/tests/expected/atomics.txt @@ -0,0 +1,43 @@ +struct gen___atomic_compare_exchange_resultUint4_ { + old_value: u32, + exchanged: bool, +} + +var _naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom: atomic; + +fn _naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberentry_point() -> f32 { + var y: u32; + + atomicStore((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 1u); + let _e3: u32 = atomicLoad((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom)); + y = _e3; + let _e7: u32 = atomicAdd((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 2u); + let _e8: u32 = y; + y = (_e8 + _e7); + let _e12: u32 = atomicSub((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 1u); + let _e13: u32 = y; + y = (_e13 + _e12); + let _e17: u32 = atomicMax((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 5u); + let _e18: u32 = y; + y = (_e18 + _e17); + let _e22: u32 = atomicMin((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 4u); + let _e23: u32 = y; + y = (_e23 + _e22); + let _e25: u32 = y; + let _e27: u32 = atomicExchange((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), _e25); + let _e28: u32 = y; + y = (_e28 + _e27); + let _e33: gen___atomic_compare_exchange_resultUint4_ = atomicCompareExchangeWeak((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 12u, 0u); + if _e33.exchanged { + let _e36: u32 = y; + y = (_e36 + _e33.old_value); + } + let _e38: u32 = y; + return f32(_e38); +} + +fn main() -> f32 { + let _e0: f32 = _naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberentry_point(); + return _e0; +} + diff --git a/src/derive.rs b/src/derive.rs index 3ae887a..2cd9d45 100644 --- a/src/derive.rs +++ b/src/derive.rs @@ -1,8 +1,8 @@ use indexmap::IndexMap; use naga::{ - Arena, Block, Constant, EntryPoint, Expression, Function, FunctionArgument, FunctionResult, - GlobalVariable, Handle, ImageQuery, LocalVariable, Module, SampleLevel, Span, Statement, - StructMember, SwitchCase, Type, TypeInner, UniqueArena, + Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument, + FunctionResult, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, SampleLevel, Span, + Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena, }; use std::{cell::RefCell, collections::HashMap, rc::Rc}; @@ -322,12 +322,22 @@ impl<'a> DerivedModule<'a> { fun, value, result, - } => Statement::Atomic { - pointer: map_expr!(pointer), - fun: *fun, - value: map_expr!(value), - result: map_expr!(result), - }, + } => { + let fun = match fun { + AtomicFunction::Exchange { + compare: Some(compare_expr), + } => AtomicFunction::Exchange { + compare: Some(map_expr!(compare_expr)), + }, + fun => *fun, + }; + Statement::Atomic { + pointer: map_expr!(pointer), + fun, + value: map_expr!(value), + result: map_expr!(result), + } + } Statement::WorkGroupUniformLoad { pointer, result } => { Statement::WorkGroupUniformLoad { pointer: map_expr!(pointer), @@ -568,8 +578,15 @@ impl<'a> DerivedModule<'a> { expr.clone() } - Expression::AtomicResult { .. } => expr.clone(), - Expression::WorkGroupUniformLoadResult { .. } => expr.clone(), + Expression::AtomicResult { ty, comparison } => Expression::AtomicResult { + ty: self.import_type(ty), + comparison: *comparison, + }, + Expression::WorkGroupUniformLoadResult { ty } => { + Expression::WorkGroupUniformLoadResult { + ty: self.import_type(ty), + } + } Expression::RayQueryProceedResult => expr.clone(), Expression::RayQueryGetIntersection { query, committed } => { Expression::RayQueryGetIntersection {