diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index 80543d434b..efca0e04dc 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -58,30 +58,6 @@ impl ConcreteConstructorHandle { } } -enum Components<'a> { - None, - One { - component: Handle, - span: Span, - ty_inner: &'a crate::TypeInner, - }, - Many { - components: Vec>, - spans: Vec, - first_component_ty_inner: &'a crate::TypeInner, - }, -} - -impl Components<'_> { - fn into_components_vec(self) -> Vec> { - match self { - Self::None => vec![], - Self::One { component, .. } => vec![component], - Self::Many { components, .. } => components, - } - } -} - impl<'source, 'temp> Lowerer<'source, 'temp> { /// Generate Naga IR for a type constructor expression. /// @@ -106,51 +82,42 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ) -> Result, Error<'source>> { let constructor_h = self.constructor(constructor, ctx)?; - let components = match *components { - [] => Components::None, - [component] => { + let components: Vec<(Handle, Span)> = components + .iter() + .cloned() + .map(|component| { + let lowered = self.expression(component, ctx)?; let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx)?; - let ty_inner = super::resolve_inner!(ctx, component); + Ok((lowered, span)) + }) + .collect::>()?; + let first_component_ty = match components.first() { + Some(&(first, _)) => Some(super::resolve_inner!(ctx, first)), + None => None, + }; - Components::One { - component, - span, - ty_inner, - } + let constructor = constructor_h.borrow(ctx.module); + + let expr = match (&components[..], first_component_ty, constructor) { + // INTERNAL ERRORS + + // If we have any components at all, we should have had the first + // component's type. + (&[_, ..], None, _) => { + return Err(Error::Internal( + "construction couldn't find first component type", + )); } - [component, ref rest @ ..] => { - let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx)?; - - let components = std::iter::once(Ok(component)) - .chain( - rest.iter() - .map(|&component| self.expression(component, ctx)), - ) - .collect::>()?; - let spans = std::iter::once(span) - .chain( - rest.iter() - .map(|&component| ctx.ast_expressions.get_span(component)), - ) - .collect(); - - let first_component_ty_inner = super::resolve_inner!(ctx, component); - - Components::Many { - components, - spans, - first_component_ty_inner, - } + (&[], Some(_), _) => { + return Err(Error::Internal( + "construction shouldn't have any first component type", + )); } - }; - let constructor = constructor_h.borrow(ctx.module); + // WELL-FORMED CASES - let expr = match (components, constructor) { // Empty constructor - (Components::None, dst_ty) => match dst_ty { + (&[], None, dst_ty) => match dst_ty { ConcreteConstructor::Type(ty, _) => { return ctx.append_expression(crate::Expression::ZeroValue(ty), span) } @@ -159,11 +126,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Scalar constructor & conversion (scalar -> scalar) ( - Components::One { - component, - ty_inner: &crate::TypeInner::Scalar { .. }, - .. - }, + &[(component, _span)], + Some(&crate::TypeInner::Scalar { .. }), ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { kind, width }), ) => crate::Expression::As { expr: component, @@ -173,11 +137,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Vector conversion (vector -> vector) ( - Components::One { - component, - ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, - .. - }, + &[(component, _span)], + Some(&crate::TypeInner::Vector { size: src_size, .. }), ConcreteConstructor::Type( _, &crate::TypeInner::Vector { @@ -194,31 +155,24 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Vector conversion (vector -> vector) - partial ( - Components::One { - component, - ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, - .. - }, + &[(component, _span)], + Some(&crate::TypeInner::Vector { size: src_size, .. }), ConcreteConstructor::PartialVector { size: dst_size }, ) if dst_size == src_size => { - // This is a trivial conversion: the sizes match, and a Partial - // constructor doesn't specify a scalar type, so nothing can - // possibly happen. + // This is a trivial conversion: the sizes match, and a + // `PartialVector` constructor doesn't specify a scalar type, so + // nothing can possibly happen. return Ok(component); } // Matrix conversion (matrix -> matrix) ( - Components::One { - component, - ty_inner: - &crate::TypeInner::Matrix { - columns: src_columns, - rows: src_rows, - .. - }, + &[(component, _span)], + Some(&crate::TypeInner::Matrix { + columns: src_columns, + rows: src_rows, .. - }, + }), ConcreteConstructor::Type( _, &crate::TypeInner::Matrix { @@ -235,16 +189,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Matrix conversion (matrix -> matrix) - partial ( - Components::One { - component, - ty_inner: - &crate::TypeInner::Matrix { - columns: src_columns, - rows: src_rows, - .. - }, + &[(component, _span)], + Some(&crate::TypeInner::Matrix { + columns: src_columns, + rows: src_rows, .. - }, + }), ConcreteConstructor::PartialMatrix { columns: dst_columns, rows: dst_rows, @@ -258,11 +208,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Vector constructor (splat) - infer type ( - Components::One { - component, - ty_inner: &crate::TypeInner::Scalar { .. }, - .. - }, + &[(component, _span)], + Some(&crate::TypeInner::Scalar { .. }), ConcreteConstructor::PartialVector { size }, ) => crate::Expression::Splat { size, @@ -271,16 +218,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Vector constructor (splat) ( - Components::One { - component, - ty_inner: - &crate::TypeInner::Scalar { - kind: src_kind, - width: src_width, - .. - }, + &[(component, _span)], + Some(&crate::TypeInner::Scalar { + kind: src_kind, + width: src_width, .. - }, + }), ConcreteConstructor::Type( _, &crate::TypeInner::Vector { @@ -296,44 +239,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Vector constructor (by elements) ( - Components::Many { - components, - first_component_ty_inner: - &crate::TypeInner::Scalar { kind, width } - | &crate::TypeInner::Vector { kind, width, .. }, - .. - }, + components @ &[_, _, ..], + Some( + &crate::TypeInner::Scalar { kind, width } + | &crate::TypeInner::Vector { kind, width, .. }, + ), ConcreteConstructor::PartialVector { size }, ) | ( - Components::Many { - components, - first_component_ty_inner: - &crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. }, - .. - }, + components @ &[_, _, ..], + Some(&crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. }), ConcreteConstructor::Type(_, &crate::TypeInner::Vector { size, width, kind }), ) => { let inner = crate::TypeInner::Vector { size, kind, width }; let ty = ctx.ensure_type_exists(inner); + let components = components.iter().map(|c| c.0).collect(); crate::Expression::Compose { ty, components } } // Matrix constructor (by elements) ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Scalar { width, .. }, - .. - }, + components @ &[_, _, ..], + Some(&crate::TypeInner::Scalar { width, .. }), ConcreteConstructor::PartialMatrix { columns, rows }, ) | ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Scalar { .. }, - .. - }, + components @ &[_, _, ..], + Some(&crate::TypeInner::Scalar { .. }), ConcreteConstructor::Type( _, &crate::TypeInner::Matrix { @@ -349,13 +281,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { size: rows, }); + // Constructing a matrix by elements first constructs column + // vectors, and then the matrix from those column vectors. let components = components .chunks(rows as usize) .map(|vec_components| { + let vec_components = vec_components.iter().map(|c| c.0).collect(); ctx.append_expression( crate::Expression::Compose { ty: vec_ty, - components: Vec::from(vec_components), + components: vec_components, }, Default::default(), ) @@ -372,19 +307,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Matrix constructor (by columns) ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Vector { width, .. }, - .. - }, + components @ &[_, _, ..], + Some(&crate::TypeInner::Vector { width, .. }), ConcreteConstructor::PartialMatrix { columns, rows }, ) | ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Vector { .. }, - .. - }, + components @ &[_, _, ..], + Some(&crate::TypeInner::Vector { .. }), ConcreteConstructor::Type( _, &crate::TypeInner::Matrix { @@ -399,12 +328,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rows, width, }); + let components = components.iter().map(|c| c.0).collect(); crate::Expression::Compose { ty, components } } // Array constructor - infer type - (components, ConcreteConstructor::PartialArray) => { - let components = components.into_components_vec(); + (components, _, ConcreteConstructor::PartialArray) => { + let components: Vec<_> = components.iter().map(|c| c.0).collect(); let base = ctx.register_type(components[0])?; @@ -426,19 +356,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Array or Struct constructor ( components, + _, ConcreteConstructor::Type( ty, &crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. }, ), ) => { - let components = components.into_components_vec(); + let components = components.iter().map(|c| c.0).collect(); crate::Expression::Compose { ty, components } } // ERRORS // Bad conversion (type cast) - (Components::One { span, ty_inner, .. }, _) => { + (&[(_, span)], Some(ty_inner), _) => { let from_type = ctx.format_typeinner(ty_inner); return Err(Error::BadTypeCast { span, @@ -449,16 +380,18 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Too many parameters for scalar constructor ( - Components::Many { spans, .. }, + &[_, ref second, ref rest @ ..], + _, ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { .. }), ) => { - let span = spans[1].until(spans.last().unwrap()); + let span = second.1.until(&rest.last().unwrap_or(second).1); return Err(Error::UnexpectedComponents(span)); } // Parameters are of the wrong type for vector or matrix constructor ( - Components::Many { spans, .. }, + &[(_, first_span), _, ..], + _, ConcreteConstructor::Type( _, &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. }, @@ -466,7 +399,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { | ConcreteConstructor::PartialVector { .. } | ConcreteConstructor::PartialMatrix { .. }, ) => { - return Err(Error::InvalidConstructorComponentType(spans[0], 0)); + return Err(Error::InvalidConstructorComponentType(first_span, 0)); } // Other types can't be constructed