From 8b30632dbad78c204a709709227f9cdaf6cede29 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sun, 22 Oct 2023 12:22:17 -0700 Subject: [PATCH] [wgsl-in] Unify ConcreteConstructor and ConcreteConstructorHandle. Replace the `ConcreteConstructor` and `ConcreteConstructorHandle` types in `front::wgsl::lower::construction` with a single type `Constructor` with a type parameter that determines how it refers to Naga types. In the single-argument vector construction case, construct `Compose` expressions when the lengths don't match, and leave the problem for validation to catch; `validate_compose` is a better place to invest time in improving diagnostics (and arguably producess more specific error messages already). --- src/front/wgsl/lower/construction.rs | 215 +++++++++++++++------------ 1 file changed, 122 insertions(+), 93 deletions(-) diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index 26b2c2f59d..6227167850 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -6,45 +6,54 @@ use crate::{Handle, Span}; use crate::front::wgsl::error::Error; use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; -enum ConcreteConstructorHandle { - PartialVector { - size: crate::VectorSize, - }, +/// A cooked form of `ast::ConstructorType` that uses Naga types whenever +/// possible. +enum Constructor { + /// A vector construction whose component type is inferred from the + /// argument: `vec3(1.0)`. + PartialVector { size: crate::VectorSize }, + + /// A matrix construction whose component type is inferred from the + /// argument: `mat2x2(1,2,3,4)`. PartialMatrix { columns: crate::VectorSize, rows: crate::VectorSize, }, + + /// An array whose component type and size are inferred from the arguments: + /// `array(3,4,5)`. PartialArray, - Type(Handle), + + /// A known Naga type. + /// + /// When we match on this type, we need to see the `TypeInner` here, but at + /// the point that we build this value we'll still need mutable access to + /// the module later. To avoid borrowing from the module, the type parameter + /// `T` is `Handle` initially. Then we use `borrow_inner` to produce a + /// version holding a tuple `(Handle, &TypeInner)`. + Type(T), } -impl ConcreteConstructorHandle { - fn borrow<'a>(&self, module: &'a crate::Module) -> ConcreteConstructor<'a> { - match *self { - Self::PartialVector { size } => ConcreteConstructor::PartialVector { size }, - Self::PartialMatrix { columns, rows } => { - ConcreteConstructor::PartialMatrix { columns, rows } +impl Constructor> { + /// Return an equivalent `Constructor` value that includes borrowed + /// `TypeInner` values alongside any type handles. + fn borrow_inner( + self, + module: &crate::Module, + ) -> Constructor<(Handle, &crate::TypeInner)> { + match self { + Constructor::PartialVector { size } => Constructor::PartialVector { size }, + Constructor::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } } - Self::PartialArray => ConcreteConstructor::PartialArray, - Self::Type(handle) => ConcreteConstructor::Type(handle, &module.types[handle].inner), + Constructor::PartialArray => Constructor::PartialArray, + Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)), } } } -enum ConcreteConstructor<'a> { - PartialVector { - size: crate::VectorSize, - }, - PartialMatrix { - columns: crate::VectorSize, - rows: crate::VectorSize, - }, - PartialArray, - Type(Handle, &'a crate::TypeInner), -} - -impl ConcreteConstructorHandle { - fn to_error_string(&self, ctx: &mut ExpressionContext) -> String { +impl Constructor<(Handle, &crate::TypeInner)> { + fn to_error_string(&self, ctx: &ExpressionContext) -> String { match *self { Self::PartialVector { size } => { format!("vec{}", size as u32,) @@ -53,7 +62,7 @@ impl ConcreteConstructorHandle { format!("mat{}x{}", columns as u32, rows as u32,) } Self::PartialArray => "array".to_string(), - Self::Type(ty) => ctx.format_type(ty), + Self::Type((handle, _inner)) => ctx.format_type(handle), } } } @@ -80,7 +89,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { components: &[Handle>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { - let constructor_h = self.constructor(constructor, ctx)?; + let constructor = self.constructor(constructor, ctx)?; let components: Vec<(Handle, Span)> = components .iter() @@ -96,7 +105,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { None => None, }; - let constructor = constructor_h.borrow(ctx.module); + // Even though we computed `constructor` above, wait until here to borrow + // a reference to the `TypeInner`, so that the component-handling code + // above can have mutable access to the type arena. + let constructor = constructor.borrow_inner(ctx.module); let expr = match (&components[..], first_component_ty, constructor) { // INTERNAL ERRORS @@ -117,18 +129,26 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // WELL-FORMED CASES // Empty constructor - (&[], None, dst_ty) => match dst_ty { - ConcreteConstructor::Type(ty, _) => { - return ctx.append_expression(crate::Expression::ZeroValue(ty), span) + (&[], None, constructor) => { + match constructor { + Constructor::Type((result_ty, _)) => { + return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span) + } + Constructor::PartialVector { .. } + | Constructor::PartialMatrix { .. } + | Constructor::PartialArray => { + // We have no arguments from which to infer the result type, so + // partial constructors aren't acceptable here. + return Err(Error::TypeNotInferrable(ty_span)); + } } - _ => return Err(Error::TypeNotInferrable(ty_span)), - }, + } // Scalar constructor & conversion (scalar -> scalar) ( &[(component, _span)], Some(&crate::TypeInner::Scalar { .. }), - ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { kind, width }), + Constructor::Type((_, &crate::TypeInner::Scalar { kind, width })), ) => crate::Expression::As { expr: component, kind, @@ -139,26 +159,39 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ( &[(component, _span)], Some(&crate::TypeInner::Vector { size: src_size, .. }), - ConcreteConstructor::Type( - _, + Constructor::Type(( + ty, &crate::TypeInner::Vector { size: dst_size, kind: dst_kind, width: dst_width, }, - ), - ) if dst_size == src_size => crate::Expression::As { - expr: component, - kind: dst_kind, - convert: Some(dst_width), - }, + )), + ) => { + // If the lengths match, this is an `As` conversion. Otherwise, + // go ahead and build a `Compose` expression, which (unlike + // `As`) spells out the expected type. The validator will catch + // the problem and generate a helpful error message. + if src_size == dst_size { + crate::Expression::As { + expr: component, + kind: dst_kind, + convert: Some(dst_width), + } + } else { + crate::Expression::Compose { + ty, + components: vec![component], + } + } + } // Vector conversion (vector -> vector) - partial ( &[(component, _span)], Some(&crate::TypeInner::Vector { size: src_size, .. }), - ConcreteConstructor::PartialVector { size: dst_size }, - ) if dst_size == src_size => { + Constructor::PartialVector { size: dst_size }, + ) if src_size == dst_size => { // This is a trivial conversion: the sizes match, and a // `PartialVector` constructor doesn't specify a scalar type, so // nothing can possibly happen. @@ -173,14 +206,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rows: src_rows, .. }), - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Matrix { columns: dst_columns, rows: dst_rows, width: dst_width, }, - ), + )), ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As { expr: component, kind: crate::ScalarKind::Float, @@ -195,7 +228,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rows: src_rows, .. }), - ConcreteConstructor::PartialMatrix { + Constructor::PartialMatrix { columns: dst_columns, rows: dst_rows, }, @@ -209,7 +242,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ( &[(component, _span)], Some(&crate::TypeInner::Scalar { .. }), - ConcreteConstructor::PartialVector { size }, + Constructor::PartialVector { size }, ) => crate::Expression::Splat { size, value: component, @@ -223,14 +256,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { width: src_width, .. }), - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Vector { size, kind: dst_kind, width: dst_width, }, - ), + )), ) if dst_kind == src_kind || dst_width == src_width => crate::Expression::Splat { size, value: component, @@ -243,12 +276,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &crate::TypeInner::Scalar { kind, width } | &crate::TypeInner::Vector { kind, width, .. }, ), - ConcreteConstructor::PartialVector { size }, + Constructor::PartialVector { size }, ) | ( components @ &[_, _, ..], Some(&crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. }), - ConcreteConstructor::Type(_, &crate::TypeInner::Vector { size, width, kind }), + Constructor::Type((_, &crate::TypeInner::Vector { size, width, kind })), ) => { let inner = crate::TypeInner::Vector { size, kind, width }; let ty = ctx.ensure_type_exists(inner); @@ -260,19 +293,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ( components @ &[_, _, ..], Some(&crate::TypeInner::Scalar { width, .. }), - ConcreteConstructor::PartialMatrix { columns, rows }, + Constructor::PartialMatrix { columns, rows }, ) | ( components @ &[_, _, ..], Some(&crate::TypeInner::Scalar { .. }), - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Matrix { columns, rows, width, }, - ), + )), ) => { let vec_ty = ctx.ensure_type_exists(crate::TypeInner::Vector { width, @@ -308,19 +341,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ( components @ &[_, _, ..], Some(&crate::TypeInner::Vector { width, .. }), - ConcreteConstructor::PartialMatrix { columns, rows }, + Constructor::PartialMatrix { columns, rows }, ) | ( components @ &[_, _, ..], Some(&crate::TypeInner::Vector { .. }), - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Matrix { columns, rows, width, }, - ), + )), ) => { let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, @@ -332,7 +365,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } // Array constructor - infer type - (components, _, ConcreteConstructor::PartialArray) => { + (components, _, Constructor::PartialArray) => { let components: Vec<_> = components.iter().map(|c| c.0).collect(); let base = ctx.register_type(components[0])?; @@ -356,10 +389,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ( components, _, - ConcreteConstructor::Type( + Constructor::Type(( ty, &crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. }, - ), + )), ) => { let components = components.iter().map(|c| c.0).collect(); crate::Expression::Compose { ty, components } @@ -368,12 +401,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // ERRORS // Bad conversion (type cast) - (&[(_, span)], Some(ty_inner), _) => { + (&[(_, span)], Some(ty_inner), constructor) => { let from_type = ctx.format_typeinner(ty_inner); return Err(Error::BadTypeCast { span, from_type, - to_type: constructor_h.to_error_string(ctx), + to_type: constructor.to_error_string(ctx), }); } @@ -381,7 +414,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ( &[_, ref second, ref rest @ ..], _, - ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { .. }), + Constructor::Type((_, &crate::TypeInner::Scalar { .. })), ) => { let span = second.1.until(&rest.last().unwrap_or(second).1); return Err(Error::UnexpectedComponents(span)); @@ -391,12 +424,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ( &[(_, first_span), _, ..], _, - ConcreteConstructor::Type( + Constructor::Type(( _, &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. }, - ) - | ConcreteConstructor::PartialVector { .. } - | ConcreteConstructor::PartialMatrix { .. }, + )) + | Constructor::PartialVector { .. } + | Constructor::PartialMatrix { .. }, ) => { return Err(Error::InvalidConstructorComponentType(first_span, 0)); } @@ -409,17 +442,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(expr) } - /// Build a Naga IR [`Type`] for `constructor` if there is enough - /// information to do so. + /// Build a [`Constructor`] for a WGSL construction expression. /// - /// For `Partial` variants of [`ast::ConstructorType`], we don't know the - /// component type, so in that case we return the appropriate `Partial` - /// variant of [`ConcreteConstructorHandle`]. + /// If `constructor` conveys enough information to determine which Naga [`Type`] + /// we're actually building (i.e., it's not a partial constructor), then + /// ensure the `Type` exists in [`ctx.module`], and return + /// [`Constructor::Type`]. /// - /// But for the other `ConstructorType` variants, we have everything we need - /// to know to actually produce a Naga IR type. In this case we add to/find - /// in [`ctx.module`] a suitable Naga `Type` and return a - /// [`ConcreteConstructorHandle::Type`] value holding its handle. + /// Otherwise, return the [`Constructor`] partial variant corresponding to + /// `constructor`. /// /// [`Type`]: crate::Type /// [`ctx.module`]: ExpressionContext::module @@ -427,21 +458,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, constructor: &ast::ConstructorType<'source>, ctx: &mut ExpressionContext<'source, '_, 'out>, - ) -> Result> { - let c = match *constructor { + ) -> Result>, Error<'source>> { + let handle = match *constructor { ast::ConstructorType::Scalar { width, kind } => { let ty = ctx.ensure_type_exists(crate::TypeInner::Scalar { width, kind }); - ConcreteConstructorHandle::Type(ty) - } - ast::ConstructorType::PartialVector { size } => { - ConcreteConstructorHandle::PartialVector { size } + Constructor::Type(ty) } + ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size }, ast::ConstructorType::Vector { size, kind, width } => { let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, kind, width }); - ConcreteConstructorHandle::Type(ty) + Constructor::Type(ty) } - ast::ConstructorType::PartialMatrix { rows, columns } => { - ConcreteConstructorHandle::PartialMatrix { rows, columns } + ast::ConstructorType::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } } ast::ConstructorType::Matrix { rows, @@ -453,9 +482,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rows, width, }); - ConcreteConstructorHandle::Type(ty) + Constructor::Type(ty) } - ast::ConstructorType::PartialArray => ConcreteConstructorHandle::PartialArray, + ast::ConstructorType::PartialArray => Constructor::PartialArray, ast::ConstructorType::Array { base, size } => { let base = self.resolve_ast_type(base, &mut ctx.as_global())?; let size = self.array_size(size, &mut ctx.as_global())?; @@ -464,11 +493,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let stride = self.layouter[base].to_stride(); let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride }); - ConcreteConstructorHandle::Type(ty) + Constructor::Type(ty) } - ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty), + ast::ConstructorType::Type(ty) => Constructor::Type(ty), }; - Ok(c) + Ok(handle) } }