Skip to content

Commit

Permalink
Correct atomic mapping (#49)
Browse files Browse the repository at this point in the history
Addresses #48

---------

Co-authored-by: robtfm <[email protected]>
  • Loading branch information
Joeoc2001 and robtfm authored Sep 23, 2023
1 parent b2cf1bc commit 58e2272
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 12 deletions.
4 changes: 3 additions & 1 deletion src/compose/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,9 @@ impl Composer {
let mut owned_types = HashSet::new();
for (h, ty) in source_ir.types.iter() {
if let Some(name) = &ty.name {
if !name.contains(DECORATION_PRE) {
// we need to exclude autogenerated struct names, i.e. those that begin with "__"
// "__" is a reserved prefix for naga so user variables cannot use it.
if !name.contains(DECORATION_PRE) && !name.starts_with("__") {
let name = format!("{module_decoration}{name}");
owned_types.insert(name.clone());
// copy and rename types
Expand Down
45 changes: 45 additions & 0 deletions src/compose/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,51 @@ mod test {
output_eq!(wgsl, "tests/expected/use_shared_global.txt");
}

#[test]
fn test_atomics() {
let mut composer = Composer::default();

composer
.add_composable_module(ComposableModuleDescriptor {
source: include_str!("tests/atomics/mod.wgsl"),
file_path: "tests/atomics/mod.wgsl",
..Default::default()
})
.unwrap();

// TODO enable this test when HLSL support is available
if cfg!(feature = "test_shader") && false {
assert_eq!(test_shader(&mut composer), 28.0);
}

let module = composer
.make_naga_module(NagaModuleDescriptor {
source: include_str!("tests/atomics/top.wgsl"),
file_path: "tests/atomics/top.wgsl",
..Default::default()
})
.unwrap();

let info = naga::valid::Validator::new(
naga::valid::ValidationFlags::all(),
naga::valid::Capabilities::default(),
)
.validate(&module)
.unwrap();
let wgsl = naga::back::wgsl::write_string(
&module,
&info,
naga::back::wgsl::WriterFlags::EXPLICIT_TYPES,
)
.unwrap();

// let mut f = std::fs::File::create("atomics.txt").unwrap();
// f.write_all(wgsl.as_bytes()).unwrap();
// drop(f);

output_eq!(wgsl, "tests/expected/atomics.txt");
}

#[cfg(feature = "test_shader")]
#[test]
fn effective_defs() {
Expand Down
19 changes: 19 additions & 0 deletions src/compose/tests/atomics/mod.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#define_import_path test_module

var<workgroup> atom: atomic<u32>;

fn entry_point() -> f32 {
atomicStore(&atom, 1u); // atom = 1
var y = atomicLoad(&atom); // y = 1, atom = 1
y += atomicAdd(&atom, 2u); // y = 2, atom = 3
y += atomicSub(&atom, 1u); // y = 5, atom = 2
y += atomicMax(&atom, 5u); // y = 7, atom = 5
y += atomicMin(&atom, 4u); // y = 12, atom = 4
y += atomicExchange(&atom, y); // y = 16, atom = 12
let exchange = atomicCompareExchangeWeak(&atom, 12u, 0u);
if exchange.exchanged {
y += exchange.old_value; // y = 28, atom = 0
}

return f32(y); // 28.0
}
5 changes: 5 additions & 0 deletions src/compose/tests/atomics/top.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import test_module

fn main() -> f32 {
return test_module::entry_point();
}
43 changes: 43 additions & 0 deletions src/compose/tests/expected/atomics.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
struct gen___atomic_compare_exchange_resultUint4_ {
old_value: u32,
exchanged: bool,
}

var<workgroup> _naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom: atomic<u32>;

fn _naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberentry_point() -> f32 {
var y: u32;

atomicStore((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 1u);
let _e3: u32 = atomicLoad((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom));
y = _e3;
let _e7: u32 = atomicAdd((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 2u);
let _e8: u32 = y;
y = (_e8 + _e7);
let _e12: u32 = atomicSub((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 1u);
let _e13: u32 = y;
y = (_e13 + _e12);
let _e17: u32 = atomicMax((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 5u);
let _e18: u32 = y;
y = (_e18 + _e17);
let _e22: u32 = atomicMin((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 4u);
let _e23: u32 = y;
y = (_e23 + _e22);
let _e25: u32 = y;
let _e27: u32 = atomicExchange((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), _e25);
let _e28: u32 = y;
y = (_e28 + _e27);
let _e33: gen___atomic_compare_exchange_resultUint4_ = atomicCompareExchangeWeak((&_naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberatom), 12u, 0u);
if _e33.exchanged {
let _e36: u32 = y;
y = (_e36 + _e33.old_value);
}
let _e38: u32 = y;
return f32(_e38);
}

fn main() -> f32 {
let _e0: f32 = _naga_oil_mod_ORSXG5C7NVXWI5LMMU_memberentry_point();
return _e0;
}

39 changes: 28 additions & 11 deletions src/derive.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use indexmap::IndexMap;
use naga::{
Arena, Block, Constant, EntryPoint, Expression, Function, FunctionArgument, FunctionResult,
GlobalVariable, Handle, ImageQuery, LocalVariable, Module, SampleLevel, Span, Statement,
StructMember, SwitchCase, Type, TypeInner, UniqueArena,
Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
FunctionResult, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, SampleLevel, Span,
Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
};
use std::{cell::RefCell, collections::HashMap, rc::Rc};

Expand Down Expand Up @@ -322,12 +322,22 @@ impl<'a> DerivedModule<'a> {
fun,
value,
result,
} => Statement::Atomic {
pointer: map_expr!(pointer),
fun: *fun,
value: map_expr!(value),
result: map_expr!(result),
},
} => {
let fun = match fun {
AtomicFunction::Exchange {
compare: Some(compare_expr),
} => AtomicFunction::Exchange {
compare: Some(map_expr!(compare_expr)),
},
fun => *fun,
};
Statement::Atomic {
pointer: map_expr!(pointer),
fun,
value: map_expr!(value),
result: map_expr!(result),
}
}
Statement::WorkGroupUniformLoad { pointer, result } => {
Statement::WorkGroupUniformLoad {
pointer: map_expr!(pointer),
Expand Down Expand Up @@ -568,8 +578,15 @@ impl<'a> DerivedModule<'a> {
expr.clone()
}

Expression::AtomicResult { .. } => expr.clone(),
Expression::WorkGroupUniformLoadResult { .. } => expr.clone(),
Expression::AtomicResult { ty, comparison } => Expression::AtomicResult {
ty: self.import_type(ty),
comparison: *comparison,
},
Expression::WorkGroupUniformLoadResult { ty } => {
Expression::WorkGroupUniformLoadResult {
ty: self.import_type(ty),
}
}
Expression::RayQueryProceedResult => expr.clone(),
Expression::RayQueryGetIntersection { query, committed } => {
Expression::RayQueryGetIntersection {
Expand Down

0 comments on commit 58e2272

Please sign in to comment.