Skip to content

Commit

Permalink
[wgsl-in] Handle modf and frexp (#2454)
Browse files Browse the repository at this point in the history
  • Loading branch information
fornwall authored Sep 2, 2023
1 parent f49314d commit 5329aa2
Show file tree
Hide file tree
Showing 26 changed files with 843 additions and 205 deletions.
3 changes: 3 additions & 0 deletions src/back/glsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,4 +477,7 @@ pub const RESERVED_KEYWORDS: &[&str] = &[
// entry point name (should not be shadowed)
//
"main",
// Naga utilities:
super::MODF_FUNCTION,
super::FREXP_FUNCTION,
];
54 changes: 52 additions & 2 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
/// of detail for bounds checking in `ImageLoad`
const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

/// Mapping between resources and bindings.
pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;

Expand Down Expand Up @@ -631,6 +634,53 @@ impl<'a, W: Write> Writer<'a, W> {
}
}

// Write functions to create special types.
for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() {
match type_key {
&crate::PredeclaredType::ModfResult { size, width }
| &crate::PredeclaredType::FrexpResult { size, width } => {
let arg_type_name_owner;
let arg_type_name = if let Some(size) = size {
arg_type_name_owner =
format!("{}vec{}", if width == 8 { "d" } else { "" }, size as u8);
&arg_type_name_owner
} else if width == 8 {
"double"
} else {
"float"
};

let other_type_name_owner;
let (defined_func_name, called_func_name, other_type_name) =
if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
(MODF_FUNCTION, "modf", arg_type_name)
} else {
let other_type_name = if let Some(size) = size {
other_type_name_owner = format!("ivec{}", size as u8);
&other_type_name_owner
} else {
"int"
};
(FREXP_FUNCTION, "frexp", other_type_name)
};

let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;
writeln!(
self.out,
"{} {defined_func_name}({arg_type_name} arg) {{
{other_type_name} other;
{arg_type_name} fract = {called_func_name}(arg, other);
return {}(fract, other);
}}",
struct_name, struct_name
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
}
}

// Write all named constants
let mut constants = self
.module
Expand Down Expand Up @@ -2997,8 +3047,8 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Round => "roundEven",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Modf => MODF_FUNCTION,
Mf::Frexp => FREXP_FUNCTION,
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Expand Down
53 changes: 53 additions & 0 deletions src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,59 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}

/// Write functions to create special types.
pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
match type_key {
&crate::PredeclaredType::ModfResult { size, width }
| &crate::PredeclaredType::FrexpResult { size, width } => {
let arg_type_name_owner;
let arg_type_name = if let Some(size) = size {
arg_type_name_owner = format!(
"{}{}",
if width == 8 { "double" } else { "float" },
size as u8
);
&arg_type_name_owner
} else if width == 8 {
"double"
} else {
"float"
};

let (defined_func_name, called_func_name, second_field_name, sign_multiplier) =
if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
(super::writer::MODF_FUNCTION, "modf", "whole", "")
} else {
(
super::writer::FREXP_FUNCTION,
"frexp",
"exp_",
"sign(arg) * ",
)
};

let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(
self.out,
"{struct_name} {defined_func_name}({arg_type_name} arg) {{
{arg_type_name} other;
{struct_name} result;
result.fract = {sign_multiplier}{called_func_name}(arg, other);
result.{second_field_name} = other;
return result;
}}"
)?;
writeln!(self.out)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
}
}

Ok(())
}

/// Helper function that writes compose wrapped functions
pub(super) fn write_wrapped_compose_functions(
&mut self,
Expand Down
3 changes: 3 additions & 0 deletions src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,9 @@ pub const RESERVED: &[&str] = &[
"TextureBuffer",
"ConstantBuffer",
"RayQuery",
// Naga utilities
super::writer::MODF_FUNCTION,
super::writer::FREXP_FUNCTION,
];

// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
Expand Down
9 changes: 7 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ const SPECIAL_BASE_VERTEX: &str = "base_vertex";
const SPECIAL_BASE_INSTANCE: &str = "base_instance";
const SPECIAL_OTHER: &str = "other";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

struct EpStructMember {
name: String,
ty: Handle<crate::Type>,
Expand Down Expand Up @@ -244,6 +247,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

self.write_special_functions(module)?;

self.write_wrapped_compose_functions(module, &module.const_expressions)?;

// Write all named constants
Expand Down Expand Up @@ -2675,8 +2680,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Round => Function::Regular("round"),
Mf::Fract => Function::Regular("frac"),
Mf::Trunc => Function::Regular("trunc"),
Mf::Modf => Function::Regular("modf"),
Mf::Frexp => Function::Regular("frexp"),
Mf::Modf => Function::Regular(MODF_FUNCTION),
Mf::Frexp => Function::Regular(FREXP_FUNCTION),
Mf::Ldexp => Function::Regular("ldexp"),
// exponent
Mf::Exp => Function::Regular("exp"),
Expand Down
2 changes: 2 additions & 0 deletions src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,6 @@ pub const RESERVED: &[&str] = &[
// Naga utilities
"DefaultConstructible",
"clamped_lod_e",
super::writer::FREXP_FUNCTION,
super::writer::MODF_FUNCTION,
];
61 changes: 59 additions & 2 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
/// The `sizes` slice determines whether this function writes a
Expand Down Expand Up @@ -1678,8 +1681,8 @@ impl<W: Write> Writer<W> {
Mf::Round => "rint",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Modf => MODF_FUNCTION,
Mf::Frexp => FREXP_FUNCTION,
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Expand Down Expand Up @@ -1813,6 +1816,9 @@ impl<W: Write> Writer<W> {
write!(self.out, "((")?;
self.put_expression(arg, context, false)?;
write!(self.out, ") * 57.295779513082322865)")?;
} else if fun == Mf::Modf || fun == Mf::Frexp {
write!(self.out, "{fun_name}")?;
self.put_call_parameters(iter::once(arg), context)?;
} else {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(
Expand Down Expand Up @@ -3236,6 +3242,57 @@ impl<W: Write> Writer<W> {
}
}
}

// Write functions to create special types.
for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
match type_key {
&crate::PredeclaredType::ModfResult { size, width }
| &crate::PredeclaredType::FrexpResult { size, width } => {
let arg_type_name_owner;
let arg_type_name = if let Some(size) = size {
arg_type_name_owner = format!(
"{NAMESPACE}::{}{}",
if width == 8 { "double" } else { "float" },
size as u8
);
&arg_type_name_owner
} else if width == 8 {
"double"
} else {
"float"
};

let other_type_name_owner;
let (defined_func_name, called_func_name, other_type_name) =
if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
(MODF_FUNCTION, "modf", arg_type_name)
} else {
let other_type_name = if let Some(size) = size {
other_type_name_owner = format!("int{}", size as u8);
&other_type_name_owner
} else {
"int"
};
(FREXP_FUNCTION, "frexp", other_type_name)
};

let struct_name = &self.names[&NameKey::Type(*struct_ty)];

writeln!(self.out)?;
writeln!(
self.out,
"{} {defined_func_name}({arg_type_name} arg) {{
{other_type_name} other;
{arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
return {}{{ fract, other }};
}}",
struct_name, struct_name
)?;
}
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
}
}

Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,8 +787,8 @@ impl<'w> BlockContext<'w> {
Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
Mf::Modf => MathOp::Ext(spirv::GLOp::Modf),
Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp),
Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
// geometry
Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
Expand Down
22 changes: 15 additions & 7 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ impl<W: Write> Writer<W> {
self.ep_results.clear();
}

fn is_builtin_wgsl_struct(&self, module: &Module, handle: Handle<crate::Type>) -> bool {
module
.special_types
.predeclared_types
.values()
.any(|t| *t == handle)
}

pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
self.reset(module);

Expand All @@ -109,13 +117,13 @@ impl<W: Write> Writer<W> {

// Write all structs
for (handle, ty) in module.types.iter() {
if let TypeInner::Struct {
ref members,
span: _,
} = ty.inner
{
self.write_struct(module, handle, members)?;
writeln!(self.out)?;
if let TypeInner::Struct { ref members, .. } = ty.inner {
{
if !self.is_builtin_wgsl_struct(module, handle) {
self.write_struct(module, handle, members)?;
writeln!(self.out)?;
}
}
}
}

Expand Down
Loading

0 comments on commit 5329aa2

Please sign in to comment.