Skip to content

Commit

Permalink
[spv/msl/hlsl-out] support pipeline constant value replacements
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Feb 12, 2024
1 parent b9555f7 commit 28b79ab
Show file tree
Hide file tree
Showing 24 changed files with 463 additions and 15 deletions.
11 changes: 11 additions & 0 deletions naga/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,17 @@ impl<T> Arena<T> {
.map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
}

/// Drains the arena, returning an iterator over the items stored.
pub fn drain(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, T, Span)> {
let arena = std::mem::take(self);
arena
.data
.into_iter()
.zip(arena.span_info.into_iter())
.enumerate()
.map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) })
}

/// Returns a iterator over the items stored in this arena,
/// returning both the item's handle and a mutable reference to it.
pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> {
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,12 @@ impl<'a, W: Write> Writer<'a, W> {
pipeline_options: &'a PipelineOptions,
policies: proc::BoundsCheckPolicies,
) -> Result<Self, Error> {
if !module.overrides.is_empty() {
return Err(Error::Custom(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}

// Check if the requested version is supported
if !options.version.is_supported() {
log::error!("Version {}", options.version);
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ pub enum Error {
Unimplemented(String), // TODO: Error used only during development
#[error("{0}")]
Custom(String),
#[error(transparent)]
PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError),
}

#[derive(Default)]
Expand Down
6 changes: 5 additions & 1 deletion naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&mut self,
module: &Module,
module_info: &valid::ModuleInfo,
_pipeline_options: &PipelineOptions,
pipeline_options: &PipelineOptions,
) -> Result<super::ReflectionInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let module = module.as_ref();

self.reset(module);

// Write special constants, if needed
Expand Down
8 changes: 8 additions & 0 deletions naga/src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ pub mod spv;
#[cfg(feature = "wgsl-out")]
pub mod wgsl;

#[cfg(any(
feature = "hlsl-out",
feature = "msl-out",
feature = "spv-out",
feature = "glsl-out"
))]
mod pipeline_constants;

const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
const INDENT: &str = " ";
const BAKE_PREFIX: &str = "_e";
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ pub enum Error {
UnsupportedArrayOfType(Handle<crate::Type>),
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
Expand Down
4 changes: 4 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3119,6 +3119,10 @@ impl<W: Write> Writer<W> {
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<TranslationInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let module = module.as_ref();

self.names.clear();
self.namer.reset(
module,
Expand Down
213 changes: 213 additions & 0 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
use super::PipelineConstants;
use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner};
use std::borrow::Cow;
use thiserror::Error;

#[derive(Error, Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub enum PipelineConstantError {
#[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
MissingValue(String),
#[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")]
SrcNeedsToBeFinite,
#[error("Source f64 value doesn't fit in destination")]
DstRangeTooSmall,
}

pub(super) fn process_overrides<'a>(
module: &'a Module,
pipeline_constants: &PipelineConstants,
) -> Result<Cow<'a, Module>, PipelineConstantError> {
if module.overrides.is_empty() {
return Ok(Cow::Borrowed(module));
}

let mut module = module.clone();

for (_handle, override_, span) in module.overrides.drain() {
let key = if let Some(id) = override_.id {
Cow::Owned(id.to_string())
} else if let Some(ref name) = override_.name {
Cow::Borrowed(name)
} else {
unreachable!();
};
let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
let literal = match module.types[override_.ty].inner {
TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
_ => unreachable!(),
};
module
.const_expressions
.append(Expression::Literal(literal), Span::UNDEFINED)
} else if let Some(init) = override_.init {
init
} else {
return Err(PipelineConstantError::MissingValue(key.to_string()));
};
let constant = Constant {
name: override_.name,
ty: override_.ty,
init,
};
module.constants.append(constant, span);
}

Ok(Cow::Owned(module))
}

fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
// note that in rust 0.0 == -0.0
match scalar {
Scalar::BOOL => {
// https://webidl.spec.whatwg.org/#js-boolean
let value = value != 0.0 && !value.is_nan();
Ok(Literal::Bool(value))
}
Scalar::I32 => {
// https://webidl.spec.whatwg.org/#js-long
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}

let value = value.trunc();
if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
return Err(PipelineConstantError::DstRangeTooSmall);
}

let value = value as i32;
Ok(Literal::I32(value))
}
Scalar::U32 => {
// https://webidl.spec.whatwg.org/#js-unsigned-long
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}

let value = value.trunc();
if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
return Err(PipelineConstantError::DstRangeTooSmall);
}

let value = value as u32;
Ok(Literal::U32(value))
}
Scalar::F32 => {
// https://webidl.spec.whatwg.org/#js-float
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}

let value = value as f32;
if !value.is_finite() {
return Err(PipelineConstantError::DstRangeTooSmall);
}

Ok(Literal::F32(value))
}
Scalar::F64 => {
// https://webidl.spec.whatwg.org/#js-double
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}

Ok(Literal::F64(value))
}
_ => unreachable!(),
}
}

#[test]
fn test_map_value_to_literal() {
let bool_test_cases = [
(0.0, false),
(-0.0, false),
(f64::NAN, false),
(1.0, true),
(f64::INFINITY, true),
(f64::NEG_INFINITY, true),
];
for (value, out) in bool_test_cases {
let res = Ok(Literal::Bool(out));
assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
}

for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
assert_eq!(map_value_to_literal(value, scalar), res);
}
}

// i32
assert_eq!(
map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
Ok(Literal::I32(i32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
Ok(Literal::I32(i32::MAX))
);
assert_eq!(
map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
Err(PipelineConstantError::DstRangeTooSmall)
);

// u32
assert_eq!(
map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
Ok(Literal::U32(u32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
Ok(Literal::U32(u32::MAX))
);
assert_eq!(
map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
Err(PipelineConstantError::DstRangeTooSmall)
);

// f32
assert_eq!(
map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
Ok(Literal::F32(f32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
Ok(Literal::F32(f32::MAX))
);
assert_eq!(
map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
Ok(Literal::F32(f32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
Ok(Literal::F32(f32::MAX))
);
assert_eq!(
map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
Err(PipelineConstantError::DstRangeTooSmall)
);

// f64
assert_eq!(
map_value_to_literal(f64::MIN, Scalar::F64),
Ok(Literal::F64(f64::MIN))
);
assert_eq!(
map_value_to_literal(f64::MAX, Scalar::F64),
Ok(Literal::F64(f64::MAX))
);
}
2 changes: 2 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ pub enum Error {
FeatureNotImplemented(&'static str),
#[error("module is not validated properly: {0}")]
Validation(&'static str),
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
}

#[derive(Default)]
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2018,6 +2018,16 @@ impl Writer {
debug_info: &Option<DebugInfo>,
words: &mut Vec<Word>,
) -> Result<(), Error> {
let ir_module = if let Some(pipeline_options) = pipeline_options {
crate::back::pipeline_constants::process_overrides(
ir_module,
&pipeline_options.constants,
)?
} else {
std::borrow::Cow::Borrowed(ir_module)
};
let ir_module = ir_module.as_ref();

self.reset();

// Try to find the entry point and corresponding index
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ impl<W: Write> Writer<W> {
}

pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
if !module.overrides.is_empty() {
return Err(Error::Unimplemented(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}

self.reset(module);

// Save all ep result types
Expand Down
1 change: 1 addition & 0 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ impl ExpressionConstnessTracker {
}

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
#[error("Constants cannot access function arguments")]
FunctionArg,
Expand Down
15 changes: 14 additions & 1 deletion naga/src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ pub enum ConstantError {

#[derive(Clone, Debug, thiserror::Error)]
pub enum OverrideError {
#[error("Override name and ID are missing")]
MissingNameAndID,
#[error("The type doesn't match the override")]
InvalidType,
#[error("The type is not constructible")]
Expand Down Expand Up @@ -351,14 +353,25 @@ impl Validator {
) -> Result<(), OverrideError> {
let o = &gctx.overrides[handle];

if o.name.is_none() && o.id.is_none() {
return Err(OverrideError::MissingNameAndID);
}

let type_info = &self.types[o.ty.index()];
if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
return Err(OverrideError::NonConstructibleType);
}

let decl_ty = &gctx.types[o.ty].inner;
match decl_ty {
&crate::TypeInner::Scalar(_) => {}
&crate::TypeInner::Scalar(scalar) => match scalar {
crate::Scalar::BOOL
| crate::Scalar::I32
| crate::Scalar::U32
| crate::Scalar::F32
| crate::Scalar::F64 => {}
_ => return Err(OverrideError::TypeNotScalar),
},
_ => return Err(OverrideError::TypeNotScalar),
}

Expand Down
11 changes: 11 additions & 0 deletions naga/tests/in/overrides.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(
spv: (
version: (1, 0),
separate_entry_points: true,
),
pipeline_constants: {
"0": NaN,
"1300": 1.1,
"depth": 2.3,
}
)
Loading

0 comments on commit 28b79ab

Please sign in to comment.