Skip to content

Commit

Permalink
maybe working subgroup ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Elabajaba committed May 11, 2024
1 parent 192a95d commit 4d16ed6
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 13 deletions.
3 changes: 2 additions & 1 deletion src/compose/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ impl ComposerError {
),
#[cfg(feature = "glsl")]
ComposerErrorInner::GlslParseError(e) => (
e.errors.iter()
e.errors
.iter()
.map(|naga::front::glsl::Error { kind, meta }| {
Label::primary((), map_span(meta.to_range().unwrap_or(0..0)))
.with_message(kind.to_string())
Expand Down
3 changes: 1 addition & 2 deletions src/compose/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1448,8 +1448,7 @@ mod test {
),
module: &shader_module,
entry_point: "run_test",
// TODO: Probably wrong
constants: &HashMap::new(),
compilation_options: Default::default(),
});

let bindgroup = device.create_bind_group(&BindGroupDescriptor {
Expand Down
84 changes: 74 additions & 10 deletions src/derive.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use indexmap::IndexMap;
use naga::{
Arena, AtomicFunction, Block, Constant, EntryPoint, Expression, Function, FunctionArgument,
FunctionResult, GlobalVariable, Handle, ImageQuery, LocalVariable, Module, SampleLevel, Span,
Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
FunctionResult, GatherMode, GlobalVariable, Handle, ImageQuery, LocalVariable, Module,
Override, SampleLevel, Span, Statement, StructMember, SwitchCase, Type, TypeInner, UniqueArena,
};
use std::{cell::RefCell, rc::Rc};

Expand Down Expand Up @@ -368,9 +368,69 @@ impl<'a> DerivedModule<'a> {
| Statement::Continue
| Statement::Kill
| Statement::Barrier(_) => stmt.clone(),
Statement::SubgroupBallot { result, predicate } => todo!(),
Statement::SubgroupGather { mode, argument, result } => todo!(),
Statement::SubgroupCollectiveOperation { op, collective_op, argument, result } => todo!(),
Statement::SubgroupBallot { result, predicate } => Statement::SubgroupBallot {
result: map_expr!(result),
predicate: map_expr_opt!(predicate),
},
Statement::SubgroupGather {
mode,
argument,
result,
} => match mode {
GatherMode::BroadcastFirst => {
return Statement::SubgroupGather {
mode: *mode,
argument: map_expr!(argument),
result: map_expr!(result),
}
}
GatherMode::Broadcast(hnd) => {
return Statement::SubgroupGather {
mode: GatherMode::Broadcast(map_expr!(hnd)),
argument: map_expr!(argument),
result: map_expr!(result),
}
}
GatherMode::Shuffle(hnd) => {
return Statement::SubgroupGather {
mode: GatherMode::Shuffle(map_expr!(hnd)),
argument: map_expr!(argument),
result: map_expr!(result),
}
}
GatherMode::ShuffleDown(hnd) => {
return Statement::SubgroupGather {
mode: GatherMode::ShuffleDown(map_expr!(hnd)),
argument: map_expr!(argument),
result: map_expr!(result),
}
}
GatherMode::ShuffleUp(hnd) => {
return Statement::SubgroupGather {
mode: GatherMode::ShuffleUp(map_expr!(hnd)),
argument: map_expr!(argument),
result: map_expr!(result),
}
}
GatherMode::ShuffleXor(hnd) => {
return Statement::SubgroupGather {
mode: GatherMode::ShuffleXor(map_expr!(hnd)),
argument: map_expr!(argument),
result: map_expr!(result),
}
}
},
Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => Statement::SubgroupCollectiveOperation {
op: *op,
collective_op: *collective_op,
argument: map_expr!(argument),
result: map_expr!(result),
},
}
})
.collect();
Expand Down Expand Up @@ -594,10 +654,14 @@ impl<'a> DerivedModule<'a> {
committed: *committed,
}
}
// TODO: Probably wrong
Expression::Override(_) => expr.clone(),
Expression::SubgroupBallotResult => todo!(),
Expression::SubgroupOperationResult { ty } => todo!(),
Expression::Override(_) => {
// TODO: unsure if this is correct
expr.clone()
}
Expression::SubgroupBallotResult => expr.clone(),
Expression::SubgroupOperationResult { ty } => Expression::SubgroupOperationResult {
ty: self.import_type(ty),
},
};

if !non_emitting_only || is_external {
Expand Down Expand Up @@ -754,7 +818,7 @@ impl<'a> From<DerivedModule<'a>> for naga::Module {
types: derived.types,
constants: derived.constants,
global_variables: derived.globals,
// TODO: Maybe wrong
// TODO: Need to also include override expressions
global_expressions: Rc::try_unwrap(derived.const_expressions)
.unwrap()
.into_inner(),
Expand Down

0 comments on commit 4d16ed6

Please sign in to comment.