Skip to content

Commit

Permalink
[spv/msl/hlsl-out] support pipeline constant value replacements
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Jan 5, 2024
1 parent 82ebbfd commit 340768e
Show file tree
Hide file tree
Showing 23 changed files with 368 additions and 15 deletions.
9 changes: 9 additions & 0 deletions naga/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,15 @@ impl<T> Arena<T> {
.map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
}

/// Drains the arena, returning an iterator over the items stored.
pub fn drain(self) -> impl DoubleEndedIterator<Item = (Handle<T>, T, Span)> {
self.data
.into_iter()
.zip(self.span_info.into_iter())
.enumerate()
.map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) })
}

/// Returns a iterator over the items stored in this arena,
/// returning both the item's handle and a mutable reference to it.
pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> {
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,12 @@ impl<'a, W: Write> Writer<'a, W> {
pipeline_options: &'a PipelineOptions,
policies: proc::BoundsCheckPolicies,
) -> Result<Self, Error> {
if !module.overrides.is_empty() {
return Err(Error::Custom(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}

// Check if the requested version is supported
if !options.version.is_supported() {
log::error!("Version {}", options.version);
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ pub enum Error {
Unimplemented(String), // TODO: Error used only during development
#[error("{0}")]
Custom(String),
#[error(transparent)]
PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError),
}

#[derive(Default)]
Expand Down
6 changes: 5 additions & 1 deletion naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&mut self,
module: &Module,
module_info: &valid::ModuleInfo,
_pipeline_options: &PipelineOptions,
pipeline_options: &PipelineOptions,
) -> Result<super::ReflectionInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let module = module.as_ref();

self.reset(module);

// Write special constants, if needed
Expand Down
8 changes: 8 additions & 0 deletions naga/src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ pub mod spv;
#[cfg(feature = "wgsl-out")]
pub mod wgsl;

#[cfg(any(
feature = "hlsl-out",
feature = "msl-out",
feature = "spv-out",
feature = "glsl-out"
))]
pub mod pipeline_constants;

const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
const INDENT: &str = " ";
const BAKE_PREFIX: &str = "_e";
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ pub enum Error {
UnsupportedArrayOfType(Handle<crate::Type>),
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
Expand Down
4 changes: 4 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3112,6 +3112,10 @@ impl<W: Write> Writer<W> {
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<TranslationInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let module = module.as_ref();

self.names.clear();
self.namer.reset(
module,
Expand Down
118 changes: 118 additions & 0 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use super::PipelineConstants;
use crate::{Arena, Constant, Expression, Literal, Module, Scalar, Span, TypeInner};
use std::borrow::Cow;
use thiserror::Error;

#[derive(Error, Debug, Clone)]
pub enum PipelineConstantError {
#[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
MissingValue(String),
#[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")]
SrcNeedsToBeFinite,
#[error("Source f64 value doesn't fit in destination")]
DstRangeTooSmall,
}

pub(super) fn process_overrides<'a>(
module: &'a Module,
pipeline_constants: &PipelineConstants,
) -> Result<Cow<'a, Module>, PipelineConstantError> {
if module.overrides.is_empty() {
return Ok(Cow::Borrowed(module));
}

let mut module = module.clone();
let overrides = std::mem::replace(&mut module.overrides, Arena::new());

for (_handle, override_, span) in overrides.drain() {
let key = if let Some(id) = override_.id {
Cow::Owned(id.to_string())
} else if let Some(ref name) = override_.name {
Cow::Borrowed(name)
} else {
unreachable!();
};
let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
let literal = match module.types[override_.ty].inner {
TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
_ => unreachable!(),
};
module
.const_expressions
.append(Expression::Literal(literal), Span::UNDEFINED)
} else if let Some(init) = override_.init {
init
} else {
return Err(PipelineConstantError::MissingValue(key.to_string()));
};
let constant = Constant {
name: override_.name,
ty: override_.ty,
init,
};
module.constants.append(constant, span);
}

Ok(Cow::Owned(module))
}

fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
// note that in rust 0.0 == -0.0
match scalar {
Scalar::BOOL => {
// https://webidl.spec.whatwg.org/#js-boolean
let value = value != 0.0 && !value.is_nan();
Ok(Literal::Bool(value))
}
Scalar::I32 => {
// https://webidl.spec.whatwg.org/#js-long
if value.is_finite() {
let value = value.abs().floor() * value.signum();
if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
Err(PipelineConstantError::DstRangeTooSmall)
} else {
let value = value as i32;
Ok(Literal::I32(value))
}
} else {
Err(PipelineConstantError::SrcNeedsToBeFinite)
}
}
Scalar::U32 => {
// https://webidl.spec.whatwg.org/#js-unsigned-long
if value.is_finite() {
let value = value.abs().floor() * value.signum();
if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
Err(PipelineConstantError::DstRangeTooSmall)
} else {
let value = value as u32;
Ok(Literal::U32(value))
}
} else {
Err(PipelineConstantError::SrcNeedsToBeFinite)
}
}
Scalar::F32 => {
// https://webidl.spec.whatwg.org/#js-float
if value.is_finite() {
let value = value as f32;
if value.is_finite() {
Ok(Literal::F32(value))
} else {
Err(PipelineConstantError::DstRangeTooSmall)
}
} else {
Err(PipelineConstantError::SrcNeedsToBeFinite)
}
}
Scalar::F64 => {
// https://webidl.spec.whatwg.org/#js-double
if value.is_finite() {
Ok(Literal::F64(value))
} else {
Err(PipelineConstantError::SrcNeedsToBeFinite)
}
}
_ => unreachable!(),
}
}
2 changes: 2 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ pub enum Error {
FeatureNotImplemented(&'static str),
#[error("module is not validated properly: {0}")]
Validation(&'static str),
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
}

#[derive(Default)]
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2007,6 +2007,16 @@ impl Writer {
debug_info: &Option<DebugInfo>,
words: &mut Vec<Word>,
) -> Result<(), Error> {
let ir_module = if let Some(pipeline_options) = pipeline_options {
crate::back::pipeline_constants::process_overrides(
ir_module,
&pipeline_options.constants,
)?
} else {
std::borrow::Cow::Borrowed(ir_module)
};
let ir_module = ir_module.as_ref();

self.reset();

// Try to find the entry point and corresponding index
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ impl<W: Write> Writer<W> {
}

pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
if !module.overrides.is_empty() {
return Err(Error::Unimplemented(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}

self.reset(module);

// Save all ep result types
Expand Down
15 changes: 14 additions & 1 deletion naga/src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ pub enum ConstantError {

#[derive(Clone, Debug, thiserror::Error)]
pub enum OverrideError {
#[error("Override name and ID are missing")]
MissingNameAndID,
#[error("The type doesn't match the override")]
InvalidType,
#[error("The type is not constructible")]
Expand Down Expand Up @@ -351,14 +353,25 @@ impl Validator {
) -> Result<(), OverrideError> {
let o = &gctx.overrides[handle];

if o.name.is_none() && o.id.is_none() {
return Err(OverrideError::MissingNameAndID);
}

let type_info = &self.types[o.ty.index()];
if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
return Err(OverrideError::NonConstructibleType);
}

let decl_ty = &gctx.types[o.ty].inner;
match decl_ty {
&crate::TypeInner::Scalar(_) => {}
&crate::TypeInner::Scalar(scalar) => match scalar {
crate::Scalar::BOOL
| crate::Scalar::I32
| crate::Scalar::U32
| crate::Scalar::F32
| crate::Scalar::F64 => {}
_ => return Err(OverrideError::TypeNotScalar),
},
_ => return Err(OverrideError::TypeNotScalar),
}

Expand Down
11 changes: 11 additions & 0 deletions naga/tests/in/overrides.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(
spv: (
version: (1, 0),
separate_entry_points: true,
),
pipeline_constants: {
"0": NaN,
"1300": 1.1,
"depth": 2.3,
}
)
3 changes: 3 additions & 0 deletions naga/tests/in/overrides.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@
// overridable constant.

override inferred_f32 = 2.718;

@compute @workgroup_size(1)
fn main() {}
17 changes: 16 additions & 1 deletion naga/tests/out/analysis/overrides.info.ron
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,22 @@
("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
],
functions: [],
entry_points: [],
entry_points: [
(
flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"),
available_stages: ("VERTEX | FRAGMENT | COMPUTE"),
uniformity: (
non_uniform_result: None,
requirements: (""),
),
may_kill: false,
sampling_set: [],
global_uses: [],
expressions: [],
sampling: [],
dual_source_blending: false,
),
],
const_expression_types: [
Value(Scalar((
kind: Bool,
Expand Down
12 changes: 12 additions & 0 deletions naga/tests/out/hlsl/overrides.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
static const bool has_point_light = false;
static const float specular_param = 2.3;
static const float gain = 1.1;
static const float width = 0.0;
static const float depth = 2.3;
static const float inferred_f32_ = 2.718;

[numthreads(1, 1, 1)]
void main()
{
return;
}
12 changes: 12 additions & 0 deletions naga/tests/out/hlsl/overrides.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)
22 changes: 21 additions & 1 deletion naga/tests/out/ir/overrides.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,25 @@
Literal(F32(2.718)),
],
functions: [],
entry_points: [],
entry_points: [
(
name: "main",
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
function: (
name: Some("main"),
arguments: [],
result: None,
local_variables: [],
expressions: [],
named_expressions: {},
body: [
Return(
value: None,
),
],
),
),
],
)
Loading

0 comments on commit 340768e

Please sign in to comment.