From 4d16ed642b089a6fd80e8b0e08729720907f787c Mon Sep 17 00:00:00 2001 From: Elabajaba Date: Fri, 10 May 2024 23:53:34 -0400 Subject: [PATCH] maybe working subgroup ops --- src/compose/error.rs | 3 +- src/compose/test.rs | 3 +- src/derive.rs | 84 ++++++++++++++++++++++++++++++++++++++------ 3 files changed, 77 insertions(+), 13 deletions(-) diff --git a/src/compose/error.rs b/src/compose/error.rs index 43b7e69..171e3f3 100644 --- a/src/compose/error.rs +++ b/src/compose/error.rs @@ -226,7 +226,8 @@ impl ComposerError { ), #[cfg(feature = "glsl")] ComposerErrorInner::GlslParseError(e) => ( - e.errors.iter() + e.errors + .iter() .map(|naga::front::glsl::Error { kind, meta }| { Label::primary((), map_span(meta.to_range().unwrap_or(0..0))) .with_message(kind.to_string()) diff --git a/src/compose/test.rs b/src/compose/test.rs index 04eed78..de27cdc 100644 --- a/src/compose/test.rs +++ b/src/compose/test.rs @@ -1448,8 +1448,7 @@ mod test { ), module: &shader_module, entry_point: "run_test", - // TODO: Probably wrong - constants: &HashMap::new(), + compilation_options: Default::default(), }); let bindgroup = device.create_bind_group(&BindGroupDescriptor { diff --git a/src/derive.rs b/src/derive.rs index 1dfdb05..692a5dc 100644 --- a/src/derive.rs +++ b/src/derive.rs @@ -1,8 +1,8 @@ use indexmap::IndexMap; use naga::{ Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument, - FunctionResult, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, SampleLevel, Span, - Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena, + FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, + Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena, }; use std::{cell::RefCell, rc::Rc}; @@ -368,9 +368,69 @@ impl<'a> DerivedModule<'a> { | Statement::Continue | Statement::Kill | Statement::Barrier(_) => stmt.clone(), - Statement::SubgroupBallot { result, predicate } => todo!(), - Statement::SubgroupGather { mode, argument, result } => todo!(), - Statement::SubgroupCollectiveOperation { op, collective_op, argument, result } => todo!(), + Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot { + result: map_expr!(result), + predicate: map_expr_opt!(predicate), + }, + Statement::SubgroupGather { + mode, + argument, + result, + } => match mode { + GatherMode::BroadcastFirst => { + return Statement::SubgroupGather { + mode: *mode, + argument: map_expr!(argument), + result: map_expr!(result), + } + } + GatherMode::Broadcast(hnd) => { + return Statement::SubgroupGather { + mode: GatherMode::Broadcast(map_expr!(hnd)), + argument: map_expr!(argument), + result: map_expr!(result), + } + } + GatherMode::Shuffle(hnd) => { + return Statement::SubgroupGather { + mode: GatherMode::Shuffle(map_expr!(hnd)), + argument: map_expr!(argument), + result: map_expr!(result), + } + } + GatherMode::ShuffleDown(hnd) => { + return Statement::SubgroupGather { + mode: GatherMode::ShuffleDown(map_expr!(hnd)), + argument: map_expr!(argument), + result: map_expr!(result), + } + } + GatherMode::ShuffleUp(hnd) => { + return Statement::SubgroupGather { + mode: GatherMode::ShuffleUp(map_expr!(hnd)), + argument: map_expr!(argument), + result: map_expr!(result), + } + } + GatherMode::ShuffleXor(hnd) => { + return Statement::SubgroupGather { + mode: GatherMode::ShuffleXor(map_expr!(hnd)), + argument: map_expr!(argument), + result: map_expr!(result), + } + } + }, + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => Statement::SubgroupCollectiveOperation { + op: *op, + collective_op: *collective_op, + argument: map_expr!(argument), + result: map_expr!(result), + }, } }) .collect(); @@ -594,10 +654,14 @@ impl<'a> DerivedModule<'a> { committed: *committed, } } - // TODO: Probably wrong - Expression::Override(_) => expr.clone(), - Expression::SubgroupBallotResult => todo!(), - Expression::SubgroupOperationResult { ty } => todo!(), + Expression::Override(_) => { + // TODO: unsure if this is correct + expr.clone() + } + Expression::SubgroupBallotResult => expr.clone(), + Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult { + ty: self.import_type(ty), + }, }; if !non_emitting_only || is_external { @@ -754,7 +818,7 @@ impl<'a> From> for naga::Module { types: derived.types, constants: derived.constants, global_variables: derived.globals, - // TODO: Maybe wrong + // TODO: Need to also include override expressions global_expressions: Rc::try_unwrap(derived.const_expressions) .unwrap() .into_inner(),