Skip to content

Commit

Permalink
Implements all frontends and backends.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lichtso committed Oct 14, 2023
1 parent a50d65a commit a3b6b6d
Show file tree
Hide file tree
Showing 24 changed files with 977 additions and 231 deletions.
74 changes: 68 additions & 6 deletions src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,23 +287,85 @@ impl StatementGraph {
"SubgroupBallot"
}
S::SubgroupCollectiveOperation {
ref op,
ref collective_op,
op,
collective_op,
argument,
result,
} => {
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
"SubgroupCollectiveOperation" // FIXME
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
"SubgroupAll"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
"SubgroupAny"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
"SubgroupAdd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
"SubgroupMul"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
"SubgroupMax"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
"SubgroupMin"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
"SubgroupAnd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
"SubgroupOr"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
"SubgroupXor"
}
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupPrefixExclusiveAdd",
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupPrefixExclusiveMul",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupPrefixInclusiveAdd",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupPrefixInclusiveMul",
_ => unimplemented!(),
}
}
S::SubgroupBroadcast {
ref mode,
S::SubgroupGather {
mode,
argument,
result,
} => {
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
self.dependencies.push((id, index, "index"))
}
}
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
"SubgroupBroadcast" // FIXME
match mode {
crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
}
}
};
// Set the last node to the merge node
Expand Down
109 changes: 100 additions & 9 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2251,22 +2251,111 @@ impl<'a, W: Write> Writer<'a, W> {
Some(predicate) => self.write_expr(predicate, ctx)?,
None => write!(self.out, "true")?,
}
write!(self.out, ");")?;
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
ref op,
ref collective_op,
op,
collective_op,
argument,
result,
} => {
unimplemented!(); // FIXME:
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "subgroupAll(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "subgroupAny(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupAdd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupMul(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "subgroupMax(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "subgroupMin(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "subgroupAnd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "subgroupOr(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "subgroupXor(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupExclusiveAdd(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupExclusiveMul(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupInclusiveAdd(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupInclusiveMul(")?
}
_ => unimplemented!(),
}
self.write_expr(argument, ctx)?;
writeln!(self.out, ");")?;
}
Statement::SubgroupBroadcast {
ref mode,
Statement::SubgroupGather {
mode,
argument,
result,
} => {
unimplemented!(); // FIXME
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "subgroupBroadcastFirst(")?;
}
crate::GatherMode::Broadcast(_) => {
write!(self.out, "subgroupBroadcast(")?;
}
crate::GatherMode::Shuffle(_) => {
write!(self.out, "subgroupShuffle(")?;
}
crate::GatherMode::ShuffleDown(_) => {
write!(self.out, "subgroupShuffleDown(")?;
}
crate::GatherMode::ShuffleUp(_) => {
write!(self.out, "subgroupShuffleUp(")?;
}
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
}
self.write_expr(argument, ctx)?;
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
write!(self.out, ", ")?;
self.write_expr(index, ctx)?;
}
}
writeln!(self.out, ");")?;
}
}

Expand Down Expand Up @@ -4026,7 +4115,7 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, "{level}memoryBarrierShared();")?;
}
if flags.contains(crate::Barrier::SUB_GROUP) {
unimplemented!() // FIXME
writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
}
writeln!(self.out, "{level}barrier();")?;
Ok(())
Expand Down Expand Up @@ -4205,8 +4294,10 @@ const fn glsl_built_in(
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
// subgroup
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
Bi::NumSubgroups => "gl_NumSubgroups",
Bi::SubgroupId => "gl_SubgroupID",
Bi::SubgroupSize => "gl_SubgroupSize",
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
}
}

Expand Down
16 changes: 10 additions & 6 deletions src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,16 @@ impl crate::BuiltIn {
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",

Self::SubgroupInvocationId
| Self::SubgroupSize
| Self::BaseInstance
| Self::BaseVertex
| Self::WorkGroupSize => return Err(Error::Unimplemented(format!("builtin {self:?}"))),
Self::SubgroupSize => "WaveGetLaneCount()",
Self::SubgroupInvocationId => "WaveGetLaneIndex()",
Self::NumSubgroups => {
// FIXME
"(numthreads[0] * numthreads[1] * numthreads[2] / WaveGetLaneCount())"
}
Self::SubgroupId => "(SV_GroupIndex / WaveGetLaneCount())", // FIXME
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {self:?}")))
}
Self::PointSize | Self::ViewIndex | Self::PointCoord => {
return Err(Error::Custom(format!("Unsupported builtin {self:?}")))
}
Expand Down
110 changes: 102 additions & 8 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2010,7 +2010,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;

let name = format!("{}{}", back::BAKE_PREFIX, result.index());
write!(self.out, "const uint4 {name} = ")?;
self.named_expressions.insert(result, name);
Expand All @@ -2023,19 +2022,114 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
ref op,
ref collective_op,
op,
collective_op,
argument,
result,
} => {
unimplemented!(); // FIXME
write!(self.out, "{level}")?;
write!(self.out, "const ")?;
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
match func_ctx.info[result].ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {name} = ")?;
self.named_expressions.insert(result, name);

match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "WaveActiveAllTrue(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "WaveActiveAnyTrue(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "WaveActiveSum(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "WaveActiveProduct(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "WaveActiveMax(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "WaveActiveMin(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "WaveActiveBitAnd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "WaveActiveBitOr(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "WaveActiveBitXor(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "WavePrefixSum(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "WavePrefixProduct(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
self.write_expr(module, argument, func_ctx)?;
write!(self.out, " + WavePrefixSum(")?;
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
self.write_expr(module, argument, func_ctx)?;
write!(self.out, " * WavePrefixProduct(")?;
}
_ => unimplemented!(),
}
self.write_expr(module, argument, func_ctx)?;
writeln!(self.out, ");")?;
}
Statement::SubgroupBroadcast {
ref mode,
Statement::SubgroupGather {
mode,
argument,
result,
} => {
unimplemented!(); // FIXME
write!(self.out, "{level}")?;
write!(self.out, "const ")?;
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
match func_ctx.info[result].ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {name} = ")?;
self.named_expressions.insert(result, name);

if matches!(mode, crate::GatherMode::BroadcastFirst) {
write!(self.out, "WaveReadLaneFirst(")?;
self.write_expr(module, argument, func_ctx)?;
} else {
write!(self.out, "WaveReadLaneAt(")?;
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ", ")?;
match mode {
crate::GatherMode::BroadcastFirst => unreachable!(),
crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleDown(index) => {
write!(self.out, "WaveGetLaneIndex() + ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleUp(index) => {
write!(self.out, "WaveGetLaneIndex() - ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleXor(index) => {
write!(self.out, "WaveGetLaneIndex() ^ ")?;
self.write_expr(module, index, func_ctx)?;
}
}
}
writeln!(self.out, ");")?;
}
}

Expand Down Expand Up @@ -3289,7 +3383,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
}
if barrier.contains(crate::Barrier::SUB_GROUP) {
unimplemented!() // FIXME
// Does not exist in DirectX
}
Ok(())
}
Expand Down
Loading

0 comments on commit a3b6b6d

Please sign in to comment.