From 6a494a83a6e27aa7670040984bc835d1b8cb39ac Mon Sep 17 00:00:00 2001 From: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com> Date: Thu, 7 Dec 2023 20:19:43 +0100 Subject: [PATCH] [wgsl-in] add support for override declarations (#4793) Co-authored-by: Jim Blandy --- naga/src/back/dot/mod.rs | 1 + naga/src/back/glsl/mod.rs | 1 + naga/src/back/hlsl/writer.rs | 3 + naga/src/back/msl/writer.rs | 3 + naga/src/back/spv/block.rs | 3 + naga/src/back/wgsl/writer.rs | 3 + naga/src/compact/expressions.rs | 9 +++ naga/src/compact/functions.rs | 2 + naga/src/compact/mod.rs | 19 ++++++ naga/src/front/spv/function.rs | 2 + naga/src/front/spv/mod.rs | 3 +- naga/src/front/wgsl/error.rs | 17 ++++-- naga/src/front/wgsl/index.rs | 1 + naga/src/front/wgsl/lower/mod.rs | 70 +++++++++++++++++++-- naga/src/front/wgsl/parse/ast.rs | 9 +++ naga/src/front/wgsl/parse/mod.rs | 30 +++++++++ naga/src/front/wgsl/to_wgsl.rs | 1 + naga/src/lib.rs | 37 +++++++---- naga/src/proc/constant_evaluator.rs | 23 ++++++- naga/src/proc/mod.rs | 2 + naga/src/proc/typifier.rs | 3 + naga/src/valid/analyzer.rs | 3 +- naga/src/valid/expression.rs | 2 +- naga/src/valid/handles.rs | 39 +++++++++++- naga/src/valid/mod.rs | 57 +++++++++++++++++ naga/tests/in/overrides.wgsl | 14 +++++ naga/tests/out/analysis/overrides.info.ron | 26 ++++++++ naga/tests/out/ir/access.compact.ron | 1 + naga/tests/out/ir/access.ron | 1 + naga/tests/out/ir/collatz.compact.ron | 1 + naga/tests/out/ir/collatz.ron | 1 + naga/tests/out/ir/overrides.compact.ron | 71 ++++++++++++++++++++++ naga/tests/out/ir/overrides.ron | 71 ++++++++++++++++++++++ naga/tests/out/ir/shadow.compact.ron | 1 + naga/tests/out/ir/shadow.ron | 1 + naga/tests/snapshots.rs | 8 +++ naga/tests/wgsl_errors.rs | 4 +- 37 files changed, 515 insertions(+), 28 deletions(-) create mode 100644 naga/tests/in/overrides.wgsl create mode 100644 naga/tests/out/analysis/overrides.info.ron create mode 100644 naga/tests/out/ir/overrides.compact.ron create mode 100644 naga/tests/out/ir/overrides.ron diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1556371df1..d128c855ca 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -404,6 +404,7 @@ fn write_function_expressions( let (label, color_id) = match *expression { E::Literal(_) => ("Literal".into(), 2), E::Constant(_) => ("Constant".into(), 2), + E::Override(_) => ("Override".into(), 2), E::ZeroValue(_) => ("ZeroValue".into(), 2), E::Compose { ref components, .. } => { payload = Some(Payload::Arguments(components)); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index d38110c89b..2eba4ea40c 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2530,6 +2530,7 @@ impl<'a, W: Write> Writer<'a, W> { |writer, expr| writer.write_expr(expr, ctx), )?; } + Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())), // `Access` is applied to arrays, vectors and matrices and is written as indexing Expression::Access { base, index } => { self.write_expr(base, ctx)?; diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 42540ce557..6e0cd91d88 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2156,6 +2156,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => { + return Err(Error::Unimplemented("overrides are WIP".into())) + } // All of the multiplication can be expressed as `mul`, // except vector * vector, which needs to use the "*" operator. Expression::Binary { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f900add71e..7f2e37d83a 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1401,6 +1401,9 @@ impl Writer { |writer, context, expr| writer.put_expression(expr, context, true), )?; } + crate::Expression::Override(_) => { + return Err(Error::FeatureNotImplemented("overrides are WIP".into())) + } crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { // This is an acceptable place to generate a `ReadZeroSkipWrite` check. diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 6c96fa09e3..4eca34168c 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -239,6 +239,9 @@ impl<'w> BlockContext<'w> { let init = self.ir_module.constants[handle].init; self.writer.constant_ids[init.index()] } + crate::Expression::Override(_) => { + return Err(Error::FeatureNotImplemented("overrides are WIP")) + } crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index c737934f5e..bd4d5f17d7 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1190,6 +1190,9 @@ impl Writer { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => { + return Err(Error::Unimplemented("overrides are WIP".into())) + } Expression::FunctionArgument(pos) => { let name_key = func_ctx.argument_key(pos); let name = &self.names[&name_key]; diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 301bbe3240..21c4c9cdc2 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle}; pub struct ExpressionTracer<'tracer> { pub constants: &'tracer Arena, + pub overrides: &'tracer Arena, /// The arena in which we are currently tracing expressions. pub expressions: &'tracer Arena, @@ -88,6 +89,11 @@ impl<'tracer> ExpressionTracer<'tracer> { None => self.expressions_used.insert(init), } } + Ex::Override(_) => { + // All overrides are considered used by definition. We mark + // their types and initialization expressions as used in + // `compact::compact`, so we have no more work to do here. + } Ex::ZeroValue(ty) => self.types_used.insert(ty), Ex::Compose { ty, ref components } => { self.types_used.insert(ty); @@ -219,6 +225,9 @@ impl ModuleMap { | Ex::CallResult(_) | Ex::RayQueryProceedResult => {} + // All overrides are retained, so their handles never change. + Ex::Override(_) => {} + // Expressions that contain handles that need to be adjusted. Ex::Constant(ref mut constant) => self.constants.adjust(constant), Ex::ZeroValue(ref mut ty) => self.types.adjust(ty), diff --git a/naga/src/compact/functions.rs b/naga/src/compact/functions.rs index b0d08c7e96..98a23acee0 100644 --- a/naga/src/compact/functions.rs +++ b/naga/src/compact/functions.rs @@ -4,6 +4,7 @@ use super::{FunctionMap, ModuleMap}; pub struct FunctionTracer<'a> { pub function: &'a crate::Function, pub constants: &'a crate::Arena, + pub overrides: &'a crate::Arena, pub types_used: &'a mut HandleSet, pub constants_used: &'a mut HandleSet, @@ -47,6 +48,7 @@ impl<'a> FunctionTracer<'a> { fn as_expression(&mut self) -> super::expressions::ExpressionTracer { super::expressions::ExpressionTracer { constants: self.constants, + overrides: self.overrides, expressions: &self.function.expressions, types_used: self.types_used, diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index 7dfb8ee80d..843a0ccf53 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -54,6 +54,14 @@ pub fn compact(module: &mut crate::Module) { } } + // We treat all overrides as used by definition. + for (_, override_) in module.overrides.iter() { + module_tracer.types_used.insert(override_.ty); + if let Some(init) = override_.init { + module_tracer.const_expressions_used.insert(init); + } + } + // We assume that all functions are used. // // Observe which types, constant expressions, constants, and @@ -158,6 +166,15 @@ pub fn compact(module: &mut crate::Module) { } }); + // Adjust override types and initializers. + log::trace!("adjusting overrides"); + for (_, override_) in module.overrides.iter_mut() { + module_map.types.adjust(&mut override_.ty); + if let Some(init) = override_.init.as_mut() { + module_map.const_expressions.adjust(init); + } + } + // Adjust global variables' types and initializers. log::trace!("adjusting global variables"); for (_, global) in module.global_variables.iter_mut() { @@ -235,6 +252,7 @@ impl<'module> ModuleTracer<'module> { expressions::ExpressionTracer { expressions: &self.module.const_expressions, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, expressions_used: &mut self.const_expressions_used, @@ -249,6 +267,7 @@ impl<'module> ModuleTracer<'module> { FunctionTracer { function, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, const_expressions_used: &mut self.const_expressions_used, diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 198d9c52dd..8a7e736edd 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -128,6 +128,7 @@ impl> super::Frontend { expressions: &mut fun.expressions, local_arena: &mut fun.local_variables, const_arena: &mut module.constants, + overrides: &mut module.overrides, const_expressions: &mut module.const_expressions, type_arena: &module.types, global_arena: &module.global_variables, @@ -573,6 +574,7 @@ impl<'function> BlockContext<'function> { crate::proc::GlobalCtx { types: self.type_arena, constants: self.const_arena, + overrides: self.overrides, const_expressions: self.const_expressions, } } diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 73b3c7de3d..ada5dc1a9a 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -531,6 +531,7 @@ struct BlockContext<'function> { local_arena: &'function mut Arena, /// Constants arena of the module being processed const_arena: &'function mut Arena, + overrides: &'function mut Arena, const_expressions: &'function mut Arena, /// Type arena of the module being processed type_arena: &'function UniqueArena, @@ -3932,7 +3933,7 @@ impl> Frontend { Op::TypeImage => self.parse_type_image(inst, &mut module), Op::TypeSampledImage => self.parse_type_sampled_image(inst), Op::TypeSampler => self.parse_type_sampler(inst, &mut module), - Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), + Op::Constant => self.parse_constant(inst, &mut module), Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 5b3657f1f1..9514723447 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -190,7 +190,7 @@ pub enum Error<'a> { expected: String, got: String, }, - MissingType(Span), + DeclMissingTypeAndInit(Span), MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), @@ -269,6 +269,7 @@ pub enum Error<'a> { scalar: String, inner: ConstantEvaluatorError, }, + PipelineConstantIDValue(Span), } impl<'a> Error<'a> { @@ -518,11 +519,11 @@ impl<'a> Error<'a> { notes: vec![], } } - Error::MissingType(name_span) => ParseError { - message: format!("variable `{}` needs a type", &source[name_span]), + Error::DeclMissingTypeAndInit(name_span) => ParseError { + message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]), labels: vec![( name_span, - format!("definition of `{}`", &source[name_span]).into(), + "needs a type specifier or initializer".into(), )], notes: vec![], }, @@ -770,6 +771,14 @@ impl<'a> Error<'a> { format!("the expression should have been converted to have {} scalar type", scalar), ] }, + Error::PipelineConstantIDValue(span) => ParseError { + message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(), + labels: vec![( + span, + "must be between 0 and 65535 inclusive".into(), + )], + notes: vec![], + }, } } } diff --git a/naga/src/front/wgsl/index.rs b/naga/src/front/wgsl/index.rs index a5524fe8f1..593405508f 100644 --- a/naga/src/front/wgsl/index.rs +++ b/naga/src/front/wgsl/index.rs @@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> { ast::GlobalDeclKind::Fn(ref f) => f.name, ast::GlobalDeclKind::Var(ref v) => v.name, ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Override(ref o) => o.name, ast::GlobalDeclKind::Struct(ref s) => s.name, ast::GlobalDeclKind::Type(ref t) => t.name, } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index c3aa6a932b..87b3732eff 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -786,6 +786,7 @@ enum LoweredGlobalDecl { Function(Handle), Var(Handle), Const(Handle), + Override(Handle), Type(Handle), EntryPoint, } @@ -965,6 +966,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx.globals .insert(c.name.name, LoweredGlobalDecl::Const(handle)); } + ast::GlobalDeclKind::Override(ref o) => { + let init = o + .init + .map(|init| self.expression(init, &mut ctx.as_const())) + .transpose()?; + let inferred_type = init + .map(|init| ctx.as_const().register_type(init)) + .transpose()?; + + let explicit_ty = + o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) + .transpose()?; + + let id = + o.id.map(|id| self.const_u32(id, &mut ctx.as_const())) + .transpose()?; + + let id = if let Some((id, id_span)) = id { + Some( + u16::try_from(id) + .map_err(|_| Error::PipelineConstantIDValue(id_span))?, + ) + } else { + None + }; + + let ty = match (explicit_ty, inferred_type) { + (Some(explicit_ty), Some(inferred_type)) => { + if explicit_ty == inferred_type { + explicit_ty + } else { + let gctx = ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: o.name.span, + expected: explicit_ty.to_wgsl(&gctx), + got: inferred_type.to_wgsl(&gctx), + }); + } + } + (Some(explicit_ty), None) => explicit_ty, + (None, Some(inferred_type)) => inferred_type, + (None, None) => { + return Err(Error::DeclMissingTypeAndInit(o.name.span)); + } + }; + + let handle = ctx.module.overrides.append( + crate::Override { + name: Some(o.name.name.to_string()), + id, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(o.name.name, LoweredGlobalDecl::Override(handle)); + } ast::GlobalDeclKind::Struct(ref s) => { let handle = self.r#struct(s, span, &mut ctx)?; ctx.globals @@ -1202,7 +1262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty = explicit_ty; initializer = None; } - (None, None) => return Err(Error::MissingType(v.name.span)), + (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)), } let (const_initializer, initializer) = { @@ -1816,9 +1876,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; Ok(Some(handle)) } - Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { - Err(Error::Unexpected(function.span, ExpectedToken::Function)) - } + Some( + &LoweredGlobalDecl::Const(_) + | &LoweredGlobalDecl::Override(_) + | &LoweredGlobalDecl::Var(_), + ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)), Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), Some(&LoweredGlobalDecl::Function(function)) => { let arguments = arguments diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index dbaac523cb..ea8013ee7c 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> { Fn(Function<'a>), Var(GlobalVariable<'a>), Const(Const<'a>), + Override(Override<'a>), Struct(Struct<'a>), Type(TypeAlias<'a>), } @@ -200,6 +201,14 @@ pub struct Const<'a> { pub init: Handle>, } +#[derive(Debug)] +pub struct Override<'a> { + pub name: Ident<'a>, + pub id: Option>>, + pub ty: Option>>, + pub init: Option>>, +} + /// The size of an [`Array`] or [`BindingArray`]. /// /// [`Array`]: Type::Array diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 51fc2f013b..810e67f9fe 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -2170,6 +2170,7 @@ impl Parser { let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); + let mut id = ParsedAttribute::default(); let mut dependencies = FastIndexSet::default(); let mut ctx = ExpressionContext { @@ -2193,6 +2194,11 @@ impl Parser { bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } + ("id", name_span) => { + lexer.expect(Token::Paren('('))?; + id.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } ("vertex", name_span) => { stage.set(crate::ShaderStage::Vertex, name_span)?; } @@ -2283,6 +2289,30 @@ impl Parser { Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) } + (Token::Word("override"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + Some(self.type_decl(lexer, &mut ctx)?) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + Some(self.general_expression(lexer, &mut ctx)?) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Override(ast::Override { + name, + id: id.value, + ty, + init, + })) + } (Token::Word("var"), _) => { let mut var = self.variable_decl(lexer, &mut ctx)?; var.binding = binding.take(); diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index c8331ace09..ba6063ab46 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -226,6 +226,7 @@ mod tests { let gctx = crate::proc::GlobalCtx { types: &types, constants: &crate::Arena::new(), + overrides: &crate::Arena::new(), const_expressions: &crate::Arena::new(), }; let array = crate::TypeInner::Array { diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 6ec1fe8047..30d1041be4 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -175,7 +175,7 @@ tree. A Naga *constant expression* is one of the following [`Expression`] variants, whose operands (if any) are also constant expressions: - [`Literal`] -- [`Constant`], for [`Constant`s][const_type] whose `override` is `None` +- [`Constant`], for [`Constant`]s - [`ZeroValue`], for fixed-size types - [`Compose`] - [`Access`] @@ -194,8 +194,7 @@ A constant expression can be evaluated at module translation time. ## Override expressions A Naga *override expression* is the same as a [constant expression], -except that it is also allowed to refer to [`Constant`s][const_type] -whose `override` is something other than `None`. +except that it is also allowed to reference other [`Override`]s. An override expression can be evaluated at pipeline creation time. @@ -238,8 +237,6 @@ An override expression can be evaluated at pipeline creation time. [`Math`]: Expression::Math [`As`]: Expression::As -[const_type]: Constant - [constant expression]: index.html#constant-expressions */ @@ -886,6 +883,25 @@ pub enum Literal { AbstractFloat(f64), } +/// Pipeline-overridable constant. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Override { + pub name: Option, + /// Pipeline Constant ID. + pub id: Option, + pub ty: Handle, + + /// The default value of the pipeline-overridable constant. + /// + /// This [`Handle`] refers to [`Module::const_expressions`], not + /// any [`Function::expressions`] arena. + pub init: Option>, +} + /// Constant value. #[derive(Debug, PartialEq)] #[cfg_attr(feature = "clone", derive(Clone))] @@ -900,13 +916,6 @@ pub struct Constant { /// /// This [`Handle`] refers to [`Module::const_expressions`], not /// any [`Function::expressions`] arena. - /// - /// If `override` is `None`, then this must be a Naga - /// [constant expression]. Otherwise, this may be a Naga - /// [override expression] or [constant expression]. - /// - /// [constant expression]: index.html#constant-expressions - /// [override expression]: index.html#override-expressions pub init: Handle, } @@ -1292,6 +1301,8 @@ pub enum Expression { Literal(Literal), /// Constant value. Constant(Handle), + /// Pipeline-overridable constant. + Override(Handle), /// Zero value of a type. ZeroValue(Handle), /// Composite expression. @@ -2034,6 +2045,8 @@ pub struct Module { pub special_types: SpecialTypes, /// Arena for the constants defined in this module. pub constants: Arena, + /// Arena for the pipeline-overridable constants defined in this module. + pub overrides: Arena, /// Arena for the global variables defined in this module. pub global_variables: Arena, /// [Constant expressions] and [override expressions] used by this module. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 548922dc58..53718fd821 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -1,7 +1,7 @@ use crate::{ arena::{Arena, Handle, UniqueArena}, - ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner, - UnaryOperator, + ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type, + TypeInner, UnaryOperator, }; #[derive(Debug)] @@ -43,6 +43,9 @@ pub struct ConstantEvaluator<'a> { /// The module's constant arena. constants: &'a Arena, + /// The module's override arena. + overrides: &'a Arena, + /// The arena to which we are contributing expressions. expressions: &'a mut Arena, @@ -208,6 +211,7 @@ impl<'a> ConstantEvaluator<'a> { behavior, types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions: &mut module.const_expressions, function_local_data: None, } @@ -267,6 +271,7 @@ impl<'a> ConstantEvaluator<'a> { behavior, types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions, function_local_data: Some(FunctionLocalData { const_expressions: &module.const_expressions, @@ -281,6 +286,7 @@ impl<'a> ConstantEvaluator<'a> { crate::proc::GlobalCtx { types: self.types, constants: self.constants, + overrides: self.overrides, const_expressions: match self.function_local_data { Some(ref data) => data.const_expressions, None => self.expressions, @@ -357,6 +363,9 @@ impl<'a> ConstantEvaluator<'a> { // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } + Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented( + "overrides are WIP".into(), + )), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } @@ -1615,6 +1624,7 @@ mod tests { fn unary_op() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let scalar_ty = types.insert( @@ -1693,6 +1703,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -1744,6 +1755,7 @@ mod tests { fn cast() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let scalar_ty = types.insert( @@ -1776,6 +1788,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -1794,6 +1807,7 @@ mod tests { fn access() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let matrix_ty = types.insert( @@ -1891,6 +1905,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -1944,6 +1959,7 @@ mod tests { fn compose_of_constants() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let i32_ty = types.insert( @@ -1981,6 +1997,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2023,6 +2040,7 @@ mod tests { fn splat_of_constant() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let i32_ty = types.insert( @@ -2060,6 +2078,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 7dc1766b3c..831eb87040 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -633,6 +633,7 @@ impl crate::Module { GlobalCtx { types: &self.types, constants: &self.constants, + overrides: &self.overrides, const_expressions: &self.const_expressions, } } @@ -648,6 +649,7 @@ pub(super) enum U32EvalError { pub struct GlobalCtx<'a> { pub types: &'a crate::UniqueArena, pub constants: &'a crate::Arena, + pub overrides: &'a crate::Arena, pub const_expressions: &'a crate::Arena, } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 9c4403445c..845b35cb4d 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -185,6 +185,7 @@ pub enum ResolveError { pub struct ResolveContext<'a> { pub constants: &'a Arena, + pub overrides: &'a Arena, pub types: &'a UniqueArena, pub special_types: &'a crate::SpecialTypes, pub global_vars: &'a Arena, @@ -202,6 +203,7 @@ impl<'a> ResolveContext<'a> { ) -> Self { Self { constants: &module.constants, + overrides: &module.overrides, types: &module.types, special_types: &module.special_types, global_vars: &module.global_variables, @@ -407,6 +409,7 @@ impl<'a> ResolveContext<'a> { }, crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty), + crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty), crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty), crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index df6fc5e9b0..17c76b2738 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -527,7 +527,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(vector), requirements: UniformityRequirements::empty(), }, - E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), + E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(), E::Compose { ref components, .. } => { let non_uniform_result = components .iter() @@ -1139,6 +1139,7 @@ fn uniform_control_flow() { }; let resolve_context = ResolveContext { constants: &Arena::new(), + overrides: &Arena::new(), types: &type_arena, special_types: &crate::SpecialTypes::default(), global_vars: &global_var_arena, diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 54d8b3b357..f41948b910 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -343,7 +343,7 @@ impl super::Validator { self.validate_literal(literal)?; ShaderStages::all() } - E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 1884c01303..0643b1c9f5 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -31,6 +31,7 @@ impl super::Validator { pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { let &crate::Module { ref constants, + ref overrides, ref entry_points, ref functions, ref global_variables, @@ -68,7 +69,7 @@ impl super::Validator { } for handle_and_expr in const_expressions.iter() { - Self::validate_const_expression_handles(handle_and_expr, constants, types)?; + Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?; } let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -81,6 +82,19 @@ impl super::Validator { validate_const_expr(init)?; } + for (_handle, override_) in overrides.iter() { + let &crate::Override { + name: _, + id: _, + ty, + init, + } = override_; + validate_type(ty)?; + if let Some(init_expr) = init { + validate_const_expr(init_expr)?; + } + } + for (_handle, global_variable) in global_variables.iter() { let &crate::GlobalVariable { name: _, @@ -135,6 +149,7 @@ impl super::Validator { Self::validate_expression_handles( handle_and_expr, constants, + overrides, const_expressions, types, local_variables, @@ -181,6 +196,13 @@ impl super::Validator { handle.check_valid_for(constants).map(|_| ()) } + fn validate_override_handle( + handle: Handle, + overrides: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(overrides).map(|_| ()) + } + fn validate_expression_handle( handle: Handle, expressions: &Arena, @@ -198,9 +220,11 @@ impl super::Validator { fn validate_const_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + overrides: &Arena, types: &UniqueArena, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -209,6 +233,12 @@ impl super::Validator { validate_constant(constant)?; handle.check_dep(constants[constant].init)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + if let Some(init) = overrides[override_].init { + handle.check_dep(init)?; + } + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -225,6 +255,7 @@ impl super::Validator { fn validate_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + overrides: &Arena, const_expressions: &Arena, types: &UniqueArena, local_variables: &Arena, @@ -234,6 +265,7 @@ impl super::Validator { current_function: Option>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_const_expr = |handle| Self::validate_expression_handle(handle, const_expressions); let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -255,6 +287,9 @@ impl super::Validator { crate::Expression::Constant(constant) => { validate_constant(constant)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -659,6 +694,7 @@ fn constant_deps() { let mut const_exprs = Arena::new(); let mut fun_exprs = Arena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let i32_handle = types.insert( Type { @@ -686,6 +722,7 @@ fn constant_deps() { assert!(super::Validator::validate_const_expression_handles( handle_and_expr, &constants, + &overrides, &types, ) .is_err()); diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 70a4d39d2a..f5a2414e2f 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -182,6 +182,16 @@ pub enum ConstantError { NonConstructibleType, } +#[derive(Clone, Debug, thiserror::Error)] +pub enum OverrideError { + #[error("The type doesn't match the override")] + InvalidType, + #[error("The type is not constructible")] + NonConstructibleType, + #[error("The type is not a scalar")] + TypeNotScalar, +} + #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { #[error(transparent)] @@ -205,6 +215,12 @@ pub enum ValidationError { name: String, source: ConstantError, }, + #[error("Override {handle:?} '{name}' is invalid")] + Override { + handle: Handle, + name: String, + source: OverrideError, + }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle, @@ -327,6 +343,35 @@ impl Validator { Ok(()) } + fn validate_override( + &self, + handle: Handle, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), OverrideError> { + let o = &gctx.overrides[handle]; + + let type_info = &self.types[o.ty.index()]; + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(OverrideError::NonConstructibleType); + } + + let decl_ty = &gctx.types[o.ty].inner; + match decl_ty { + &crate::TypeInner::Scalar(_) => {} + _ => return Err(OverrideError::TypeNotScalar), + } + + if let Some(init) = o.init { + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(OverrideError::InvalidType); + } + } + + Ok(()) + } + /// Check the given module to be valid. pub fn validate( &mut self, @@ -404,6 +449,18 @@ impl Validator { .with_span_handle(handle, &module.constants) })? } + + for (handle, override_) in module.overrides.iter() { + self.validate_override(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::Override { + handle, + name: override_.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.overrides) + })? + } } for (var_handle, var) in module.global_variables.iter() { diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl new file mode 100644 index 0000000000..803269a656 --- /dev/null +++ b/naga/tests/in/overrides.wgsl @@ -0,0 +1,14 @@ +@id(0) override has_point_light: bool = true; // Algorithmic control +@id(1200) override specular_param: f32 = 2.3; // Numeric control +@id(1300) override gain: f32; // Must be overridden + override width: f32 = 0.0; // Specified at the API level using + // the name "width". + override depth: f32; // Specified at the API level using + // the name "depth". + // Must be overridden. + // override height = 2 * depth; // The default value + // (if not set at the API level), + // depends on another + // overridable constant. + +override inferred_f32 = 2.718; diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron new file mode 100644 index 0000000000..9ad1b3914e --- /dev/null +++ b/naga/tests/out/analysis/overrides.info.ron @@ -0,0 +1,26 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [], + const_expression_types: [ + Value(Scalar(( + kind: Bool, + width: 1, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 0670534e90..37ace5283f 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -324,6 +324,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("global_const"), diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 0670534e90..37ace5283f 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -324,6 +324,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("global_const"), diff --git a/naga/tests/out/ir/collatz.compact.ron b/naga/tests/out/ir/collatz.compact.ron index cfc3bfa0ee..fe4af55c1b 100644 --- a/naga/tests/out/ir/collatz.compact.ron +++ b/naga/tests/out/ir/collatz.compact.ron @@ -46,6 +46,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("v_indices"), diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index cfc3bfa0ee..fe4af55c1b 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -46,6 +46,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("v_indices"), diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron new file mode 100644 index 0000000000..5ac9ade6f6 --- /dev/null +++ b/naga/tests/out/ir/overrides.compact.ron @@ -0,0 +1,71 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("has_point_light"), + id: Some(0), + ty: 1, + init: Some(1), + ), + ( + name: Some("specular_param"), + id: Some(1200), + ty: 2, + init: Some(2), + ), + ( + name: Some("gain"), + id: Some(1300), + ty: 2, + init: None, + ), + ( + name: Some("width"), + id: None, + ty: 2, + init: Some(3), + ), + ( + name: Some("depth"), + id: None, + ty: 2, + init: None, + ), + ( + name: Some("inferred_f32"), + id: None, + ty: 2, + init: Some(4), + ), + ], + global_variables: [], + const_expressions: [ + Literal(Bool(true)), + Literal(F32(2.3)), + Literal(F32(0.0)), + Literal(F32(2.718)), + ], + functions: [], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron new file mode 100644 index 0000000000..5ac9ade6f6 --- /dev/null +++ b/naga/tests/out/ir/overrides.ron @@ -0,0 +1,71 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("has_point_light"), + id: Some(0), + ty: 1, + init: Some(1), + ), + ( + name: Some("specular_param"), + id: Some(1200), + ty: 2, + init: Some(2), + ), + ( + name: Some("gain"), + id: Some(1300), + ty: 2, + init: None, + ), + ( + name: Some("width"), + id: None, + ty: 2, + init: Some(3), + ), + ( + name: Some("depth"), + id: None, + ty: 2, + init: None, + ), + ( + name: Some("inferred_f32"), + id: None, + ty: 2, + init: Some(4), + ), + ], + global_variables: [], + const_expressions: [ + Literal(Bool(true)), + Literal(F32(2.3)), + Literal(F32(0.0)), + Literal(F32(2.718)), + ], + functions: [], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index 4e65180691..fab0f1e2f6 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -253,6 +253,7 @@ init: 22, ), ], + overrides: [], global_variables: [ ( name: Some("t_shadow"), diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 0b2310284a..9acbbdaadd 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -456,6 +456,7 @@ init: 38, ), ], + overrides: [], global_variables: [ ( name: Some("t_shadow"), diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 520fad46e2..37631f8a3c 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -806,6 +806,14 @@ fn convert_wgsl() { "abstract-types-operators", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, ), + ( + "overrides", + Targets::IR | Targets::ANALYSIS, // | Targets::SPIRV + // | Targets::METAL + // | Targets::GLSL + // | Targets::HLSL + // | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index 2f62491b3f..42e17dac37 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -570,11 +570,11 @@ fn local_var_missing_type() { var x; } "#, - r#"error: variable `x` needs a type + r#"error: declaration of `x` needs a type specifier or initializer ┌─ wgsl:3:21 │ 3 │ var x; - │ ^ definition of `x` + │ ^ needs a type specifier or initializer "#, );