Skip to content

Commit

Permalink
[wgsl-in] add support for override declarations (#4793)
Browse files Browse the repository at this point in the history
Co-authored-by: Jim Blandy <[email protected]>
  • Loading branch information
teoxoy and jimblandy committed Jan 8, 2024
1 parent 496bcbd commit 6a494a8
Show file tree
Hide file tree
Showing 37 changed files with 515 additions and 28 deletions.
1 change: 1 addition & 0 deletions naga/src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ fn write_function_expressions(
let (label, color_id) = match *expression {
E::Literal(_) => ("Literal".into(), 2),
E::Constant(_) => ("Constant".into(), 2),
E::Override(_) => ("Override".into(), 2),
E::ZeroValue(_) => ("ZeroValue".into(), 2),
E::Compose { ref components, .. } => {
payload = Some(Payload::Arguments(components));
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2530,6 +2530,7 @@ impl<'a, W: Write> Writer<'a, W> {
|writer, expr| writer.write_expr(expr, ctx),
)?;
}
Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())),
// `Access` is applied to arrays, vectors and matrices and is written as indexing
Expression::Access { base, index } => {
self.write_expr(base, ctx)?;
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2156,6 +2156,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => {
return Err(Error::Unimplemented("overrides are WIP".into()))
}
// All of the multiplication can be expressed as `mul`,
// except vector * vector, which needs to use the "*" operator.
Expression::Binary {
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,9 @@ impl<W: Write> Writer<W> {
|writer, context, expr| writer.put_expression(expr, context, true),
)?;
}
crate::Expression::Override(_) => {
return Err(Error::FeatureNotImplemented("overrides are WIP".into()))
}
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => {
// This is an acceptable place to generate a `ReadZeroSkipWrite` check.
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ impl<'w> BlockContext<'w> {
let init = self.ir_module.constants[handle].init;
self.writer.constant_ids[init.index()]
}
crate::Expression::Override(_) => {
return Err(Error::FeatureNotImplemented("overrides are WIP"))
}
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
crate::Expression::Compose { ty, ref components } => {
self.temp_list.clear();
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,9 @@ impl<W: Write> Writer<W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => {
return Err(Error::Unimplemented("overrides are WIP".into()))
}
Expression::FunctionArgument(pos) => {
let name_key = func_ctx.argument_key(pos);
let name = &self.names[&name_key];
Expand Down
9 changes: 9 additions & 0 deletions naga/src/compact/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle};

pub struct ExpressionTracer<'tracer> {
pub constants: &'tracer Arena<crate::Constant>,
pub overrides: &'tracer Arena<crate::Override>,

/// The arena in which we are currently tracing expressions.
pub expressions: &'tracer Arena<crate::Expression>,
Expand Down Expand Up @@ -88,6 +89,11 @@ impl<'tracer> ExpressionTracer<'tracer> {
None => self.expressions_used.insert(init),
}
}
Ex::Override(_) => {
// All overrides are considered used by definition. We mark
// their types and initialization expressions as used in
// `compact::compact`, so we have no more work to do here.
}
Ex::ZeroValue(ty) => self.types_used.insert(ty),
Ex::Compose { ty, ref components } => {
self.types_used.insert(ty);
Expand Down Expand Up @@ -219,6 +225,9 @@ impl ModuleMap {
| Ex::CallResult(_)
| Ex::RayQueryProceedResult => {}

// All overrides are retained, so their handles never change.
Ex::Override(_) => {}

// Expressions that contain handles that need to be adjusted.
Ex::Constant(ref mut constant) => self.constants.adjust(constant),
Ex::ZeroValue(ref mut ty) => self.types.adjust(ty),
Expand Down
2 changes: 2 additions & 0 deletions naga/src/compact/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::{FunctionMap, ModuleMap};
pub struct FunctionTracer<'a> {
pub function: &'a crate::Function,
pub constants: &'a crate::Arena<crate::Constant>,
pub overrides: &'a crate::Arena<crate::Override>,

pub types_used: &'a mut HandleSet<crate::Type>,
pub constants_used: &'a mut HandleSet<crate::Constant>,
Expand Down Expand Up @@ -47,6 +48,7 @@ impl<'a> FunctionTracer<'a> {
fn as_expression(&mut self) -> super::expressions::ExpressionTracer {
super::expressions::ExpressionTracer {
constants: self.constants,
overrides: self.overrides,
expressions: &self.function.expressions,

types_used: self.types_used,
Expand Down
19 changes: 19 additions & 0 deletions naga/src/compact/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ pub fn compact(module: &mut crate::Module) {
}
}

// We treat all overrides as used by definition.
for (_, override_) in module.overrides.iter() {
module_tracer.types_used.insert(override_.ty);
if let Some(init) = override_.init {
module_tracer.const_expressions_used.insert(init);
}
}

// We assume that all functions are used.
//
// Observe which types, constant expressions, constants, and
Expand Down Expand Up @@ -158,6 +166,15 @@ pub fn compact(module: &mut crate::Module) {
}
});

// Adjust override types and initializers.
log::trace!("adjusting overrides");
for (_, override_) in module.overrides.iter_mut() {
module_map.types.adjust(&mut override_.ty);
if let Some(init) = override_.init.as_mut() {
module_map.const_expressions.adjust(init);
}
}

// Adjust global variables' types and initializers.
log::trace!("adjusting global variables");
for (_, global) in module.global_variables.iter_mut() {
Expand Down Expand Up @@ -235,6 +252,7 @@ impl<'module> ModuleTracer<'module> {
expressions::ExpressionTracer {
expressions: &self.module.const_expressions,
constants: &self.module.constants,
overrides: &self.module.overrides,
types_used: &mut self.types_used,
constants_used: &mut self.constants_used,
expressions_used: &mut self.const_expressions_used,
Expand All @@ -249,6 +267,7 @@ impl<'module> ModuleTracer<'module> {
FunctionTracer {
function,
constants: &self.module.constants,
overrides: &self.module.overrides,
types_used: &mut self.types_used,
constants_used: &mut self.constants_used,
const_expressions_used: &mut self.const_expressions_used,
Expand Down
2 changes: 2 additions & 0 deletions naga/src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
expressions: &mut fun.expressions,
local_arena: &mut fun.local_variables,
const_arena: &mut module.constants,
overrides: &mut module.overrides,
const_expressions: &mut module.const_expressions,
type_arena: &module.types,
global_arena: &module.global_variables,
Expand Down Expand Up @@ -573,6 +574,7 @@ impl<'function> BlockContext<'function> {
crate::proc::GlobalCtx {
types: self.type_arena,
constants: self.const_arena,
overrides: self.overrides,
const_expressions: self.const_expressions,
}
}
Expand Down
3 changes: 2 additions & 1 deletion naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ struct BlockContext<'function> {
local_arena: &'function mut Arena<crate::LocalVariable>,
/// Constants arena of the module being processed
const_arena: &'function mut Arena<crate::Constant>,
overrides: &'function mut Arena<crate::Override>,
const_expressions: &'function mut Arena<crate::Expression>,
/// Type arena of the module being processed
type_arena: &'function UniqueArena<crate::Type>,
Expand Down Expand Up @@ -3932,7 +3933,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::TypeImage => self.parse_type_image(inst, &mut module),
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
Op::Constant => self.parse_constant(inst, &mut module),
Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
Expand Down
17 changes: 13 additions & 4 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ pub enum Error<'a> {
expected: String,
got: String,
},
MissingType(Span),
DeclMissingTypeAndInit(Span),
MissingAttribute(&'static str, Span),
InvalidAtomicPointer(Span),
InvalidAtomicOperandType(Span),
Expand Down Expand Up @@ -269,6 +269,7 @@ pub enum Error<'a> {
scalar: String,
inner: ConstantEvaluatorError,
},
PipelineConstantIDValue(Span),
}

impl<'a> Error<'a> {
Expand Down Expand Up @@ -518,11 +519,11 @@ impl<'a> Error<'a> {
notes: vec![],
}
}
Error::MissingType(name_span) => ParseError {
message: format!("variable `{}` needs a type", &source[name_span]),
Error::DeclMissingTypeAndInit(name_span) => ParseError {
message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]),
labels: vec![(
name_span,
format!("definition of `{}`", &source[name_span]).into(),
"needs a type specifier or initializer".into(),
)],
notes: vec![],
},
Expand Down Expand Up @@ -770,6 +771,14 @@ impl<'a> Error<'a> {
format!("the expression should have been converted to have {} scalar type", scalar),
]
},
Error::PipelineConstantIDValue(span) => ParseError {
message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(),
labels: vec![(
span,
"must be between 0 and 65535 inclusive".into(),
)],
notes: vec![],
},
}
}
}
1 change: 1 addition & 0 deletions naga/src/front/wgsl/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
ast::GlobalDeclKind::Fn(ref f) => f.name,
ast::GlobalDeclKind::Var(ref v) => v.name,
ast::GlobalDeclKind::Const(ref c) => c.name,
ast::GlobalDeclKind::Override(ref o) => o.name,
ast::GlobalDeclKind::Struct(ref s) => s.name,
ast::GlobalDeclKind::Type(ref t) => t.name,
}
Expand Down
70 changes: 66 additions & 4 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ enum LoweredGlobalDecl {
Function(Handle<crate::Function>),
Var(Handle<crate::GlobalVariable>),
Const(Handle<crate::Constant>),
Override(Handle<crate::Override>),
Type(Handle<crate::Type>),
EntryPoint,
}
Expand Down Expand Up @@ -965,6 +966,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(c.name.name, LoweredGlobalDecl::Const(handle));
}
ast::GlobalDeclKind::Override(ref o) => {
let init = o
.init
.map(|init| self.expression(init, &mut ctx.as_const()))
.transpose()?;
let inferred_type = init
.map(|init| ctx.as_const().register_type(init))
.transpose()?;

let explicit_ty =
o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx))
.transpose()?;

let id =
o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
.transpose()?;

let id = if let Some((id, id_span)) = id {
Some(
u16::try_from(id)
.map_err(|_| Error::PipelineConstantIDValue(id_span))?,
)
} else {
None
};

let ty = match (explicit_ty, inferred_type) {
(Some(explicit_ty), Some(inferred_type)) => {
if explicit_ty == inferred_type {
explicit_ty
} else {
let gctx = ctx.module.to_ctx();
return Err(Error::InitializationTypeMismatch {
name: o.name.span,
expected: explicit_ty.to_wgsl(&gctx),
got: inferred_type.to_wgsl(&gctx),
});
}
}
(Some(explicit_ty), None) => explicit_ty,
(None, Some(inferred_type)) => inferred_type,
(None, None) => {
return Err(Error::DeclMissingTypeAndInit(o.name.span));
}
};

let handle = ctx.module.overrides.append(
crate::Override {
name: Some(o.name.name.to_string()),
id,
ty,
init,
},
span,
);

ctx.globals
.insert(o.name.name, LoweredGlobalDecl::Override(handle));
}
ast::GlobalDeclKind::Struct(ref s) => {
let handle = self.r#struct(s, span, &mut ctx)?;
ctx.globals
Expand Down Expand Up @@ -1202,7 +1262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty = explicit_ty;
initializer = None;
}
(None, None) => return Err(Error::MissingType(v.name.span)),
(None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)),
}

let (const_initializer, initializer) = {
Expand Down Expand Up @@ -1816,9 +1876,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
Ok(Some(handle))
}
Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
Err(Error::Unexpected(function.span, ExpectedToken::Function))
}
Some(
&LoweredGlobalDecl::Const(_)
| &LoweredGlobalDecl::Override(_)
| &LoweredGlobalDecl::Var(_),
) => Err(Error::Unexpected(function.span, ExpectedToken::Function)),
Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
Some(&LoweredGlobalDecl::Function(function)) => {
let arguments = arguments
Expand Down
9 changes: 9 additions & 0 deletions naga/src/front/wgsl/parse/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> {
Fn(Function<'a>),
Var(GlobalVariable<'a>),
Const(Const<'a>),
Override(Override<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
}
Expand Down Expand Up @@ -200,6 +201,14 @@ pub struct Const<'a> {
pub init: Handle<Expression<'a>>,
}

#[derive(Debug)]
pub struct Override<'a> {
pub name: Ident<'a>,
pub id: Option<Handle<Expression<'a>>>,
pub ty: Option<Handle<Type<'a>>>,
pub init: Option<Handle<Expression<'a>>>,
}

/// The size of an [`Array`] or [`BindingArray`].
///
/// [`Array`]: Type::Array
Expand Down
Loading

0 comments on commit 6a494a8

Please sign in to comment.