From 81684a25d0ee9cb8b4761cc9253c90dc4ddf6c5b Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 19 Oct 2023 19:09:58 +0200 Subject: [PATCH] [wgsl-in] consolidate type resolution logic in a few macros --- src/front/wgsl/lower/construction.rs | 6 +- src/front/wgsl/lower/mod.rs | 102 +++++++++++++-------------- 2 files changed, 52 insertions(+), 56 deletions(-) diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index 59d7b17435..ec3a338706 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -152,8 +152,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { [component] => { let span = ctx.ast_expressions.get_span(component); let component = self.expression(component, ctx.reborrow())?; - ctx.grow_types(component)?; - let ty = &ctx.typifier()[component]; + let ty = super::resolve!(ctx, component); ComponentsHandle::One { component, @@ -178,8 +177,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ) .collect(); - ctx.grow_types(component)?; - let ty = &ctx.typifier()[component]; + let ty = super::resolve!(ctx, component); ComponentsHandle::Many { components, diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 6ff520e952..76476081e8 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -13,6 +13,32 @@ use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span}; mod construction; +macro_rules! resolve { + ($ctx:ident, $expr:expr) => {{ + $ctx.grow_types($expr)?; + &$ctx.typifier()[$expr] + }}; +} +pub(super) use resolve; + +macro_rules! resolve_inner { + ($ctx:ident, $expr:expr) => {{ + $ctx.grow_types($expr)?; + $ctx.typifier()[$expr].inner_with(&$ctx.module.types) + }}; +} + +macro_rules! resolve_inner_x2 { + ($ctx:ident, $left:expr, $right:expr) => {{ + $ctx.grow_types($left)?; + $ctx.grow_types($right)?; + ( + $ctx.typifier()[$left].inner_with(&$ctx.module.types), + $ctx.typifier()[$right].inner_with(&$ctx.module.types), + ) + }}; +} + /// State for constructing a `crate::Module`. pub struct GlobalContext<'source, 'temp, 'out> { /// The `TranslationUnit`'s expressions arena. @@ -460,15 +486,13 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { /// Determine the type of `handle`, and add it to the module's arena. /// - /// If you just need a `TypeInner` for `handle`'s type, use - /// [`grow_types`] and [`resolved_inner`] instead. This function + /// If you just need a `TypeInner` for `handle`'s type, use the + /// [`resolve_inner!`] macro instead. This function /// should only be used when the type of `handle` needs to appear /// in the module's final `Arena`, for example, if you're /// creating a [`LocalVariable`] whose type is inferred from its /// initializer. /// - /// [`grow_types`]: Self::grow_types - /// [`resolved_inner`]: Self::resolved_inner /// [`LocalVariable`]: crate::LocalVariable fn register_type( &mut self, @@ -498,12 +522,11 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { /// return a shared reference to the resulting `TypeResolution`: /// the shared reference would extend the mutable borrow, and you /// wouldn't be able to use `self` for anything else. Instead, you - /// should call `grow_types` to cover the handles you need, and - /// then use `self.typifier[handle]` or - /// [`self.resolved_inner(handle)`] to get at their resolutions. + /// should use [`register_type`] or one of [`resolve!`], + /// [`resolve_inner!`] or [`resolve_inner_x2!`]. /// /// [`self.typifier`]: ExpressionContext::typifier - /// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner + /// [`register_type`]: Self::register_type /// [`Typifier`]: Typifier fn grow_types( &mut self, @@ -533,17 +556,12 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { Ok(self) } - fn resolved_inner(&self, handle: Handle) -> &crate::TypeInner { - self.typifier()[handle].inner_with(&self.module.types) - } - fn image_data( &mut self, image: Handle, span: Span, ) -> Result<(crate::ImageClass, bool), Error<'source>> { - self.grow_types(image)?; - match *self.resolved_inner(image) { + match *resolve_inner!(self, image) { crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)), _ => Err(Error::BadTexture(span)), } @@ -584,9 +602,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { | crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo ) { - self.grow_types(*left)?.grow_types(*right)?; - - match (self.resolved_inner(*left), self.resolved_inner(*right)) { + match resolve_inner_x2!(self, *left, *right) { (&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => { *right = self.append_expression( crate::Expression::Splat { @@ -1146,11 +1162,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let initializer = match v.init { Some(init) => { - let initializer = - self.expression(init, ctx.as_expression(block, &mut emitter))?; - ctx.as_expression(block, &mut emitter) - .grow_types(initializer)?; - Some(initializer) + Some(self.expression(init, ctx.as_expression(block, &mut emitter))?) } None => None, }; @@ -1161,8 +1173,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty = match (explicit_ty, initializer) { (Some(explicit), Some(initializer)) => { - let ctx = ctx.as_expression(block, &mut emitter); - let initializer_ty = ctx.resolved_inner(initializer); + let mut ctx = ctx.as_expression(block, &mut emitter); + let initializer_ty = resolve_inner!(ctx, initializer); if !ctx.module.types[explicit] .inner .equivalent(initializer_ty, &ctx.module.types) @@ -1266,9 +1278,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut ectx = ctx.as_expression(block, &mut emitter); let selector = self.expression(selector, ectx.reborrow())?; - ectx.grow_types(selector)?; let uint = - ectx.resolved_inner(selector).scalar_kind() == Some(crate::ScalarKind::Uint); + resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint); block.extend(emitter.finish(ctx.naga_expressions)); let cases = cases @@ -1407,8 +1418,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { self.expression_for_reference(value, ctx.as_expression(block, &mut emitter))?; let mut ectx = ctx.as_expression(block, &mut emitter); - ectx.grow_types(reference.handle)?; - let (kind, width) = match *ectx.resolved_inner(reference.handle) { + let (kind, width) = match *resolve_inner!(ectx, reference.handle) { crate::TypeInner::ValuePointer { size: None, kind, @@ -1553,8 +1563,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // The pointer we dereference must be loaded. let pointer = self.expression(expr, ctx.reborrow())?; - ctx.grow_types(pointer)?; - if ctx.resolved_inner(pointer).pointer_space().is_none() { + if resolve_inner!(ctx, pointer).pointer_space().is_none() { return Err(Error::NotPointer(span)); } @@ -1583,9 +1592,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expr = self.expression_for_reference(base, ctx.reborrow())?; let index = self.expression(index, ctx.reborrow())?; - ctx.grow_types(expr.handle)?; - let wgsl_pointer = - ctx.resolved_inner(expr.handle).pointer_space().is_some() && !expr.is_reference; + let wgsl_pointer = resolve_inner!(ctx, expr.handle).pointer_space().is_some() + && !expr.is_reference; if wgsl_pointer { return Err(Error::Pointer( @@ -1618,9 +1626,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { is_reference, } = self.expression_for_reference(base, ctx.reborrow())?; - ctx.grow_types(handle)?; let temp_inner; - let (composite, wgsl_pointer) = match *ctx.resolved_inner(handle) { + let (composite, wgsl_pointer) = match *resolve_inner!(ctx, handle) { crate::TypeInner::Pointer { base, .. } => { (&ctx.module.types[base].inner, !is_reference) } @@ -1707,8 +1714,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { crate::TypeInner::Scalar { kind, .. } => kind, crate::TypeInner::Vector { kind, .. } => kind, _ => { - ctx.grow_types(expr)?; - let ty = &ctx.typifier()[expr]; + let ty = resolve!(ctx, expr); return Err(Error::BadTypeCast { from_type: ctx.format_type_resolution(ty), span: ty_span, @@ -1814,9 +1820,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fun, crate::RelationalFunction::All | crate::RelationalFunction::Any ) && { - ctx.grow_types(argument)?; matches!( - ctx.resolved_inner(argument), + resolve_inner!(ctx, argument), &crate::TypeInner::Scalar { kind: crate::ScalarKind::Bool, .. @@ -1859,8 +1864,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp { - ctx.grow_types(arg)?; - if let Some((size, width)) = match *ctx.resolved_inner(arg) { + if let Some((size, width)) = match *resolve_inner!(ctx, arg) { crate::TypeInner::Scalar { width, .. } => Some((None, width)), crate::TypeInner::Vector { size, width, .. } => { Some((Some(size), width)) @@ -2005,11 +2009,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let value = args.next()?; let value_span = ctx.ast_expressions.get_span(value); let value = self.expression(value, ctx.reborrow())?; - ctx.grow_types(value)?; args.finish()?; - let expression = match *ctx.resolved_inner(value) { + let expression = match *resolve_inner!(ctx, value) { crate::TypeInner::Scalar { kind, width } => { crate::Expression::AtomicResult { ty: ctx.module.generate_predeclared_type( @@ -2061,8 +2064,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; let pointer = self.expression(expr, ctx.reborrow())?; - ctx.grow_types(pointer)?; - let result_ty = match *ctx.resolved_inner(pointer) { + let result_ty = match *resolve_inner!(ctx, pointer) { crate::TypeInner::Pointer { base, space: crate::AddressSpace::WorkGroup, @@ -2083,7 +2085,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, ); - ctx.grow_types(pointer)?; return Ok(Some(result)); } "textureStore" => { @@ -2280,8 +2281,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let span = ctx.ast_expressions.get_span(expr); let pointer = self.expression(expr, ctx.reborrow())?; - ctx.grow_types(pointer)?; - match *ctx.resolved_inner(pointer) { + match *resolve_inner!(ctx, pointer) { crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { crate::TypeInner::Atomic { .. } => Ok(pointer), ref other => { @@ -2361,8 +2361,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let lowered_image_or_component = self.expression(image_or_component, ctx.reborrow())?; - ctx.grow_types(lowered_image_or_component)?; - match *ctx.resolved_inner(lowered_image_or_component) { + match *resolve_inner!(ctx, lowered_image_or_component) { crate::TypeInner::Image { class: crate::ImageClass::Depth { .. }, .. @@ -2683,8 +2682,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let span = ctx.ast_expressions.get_span(expr); let pointer = self.expression(expr, ctx.reborrow())?; - ctx.grow_types(pointer)?; - match *ctx.resolved_inner(pointer) { + match *resolve_inner!(ctx, pointer) { crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { crate::TypeInner::RayQuery => Ok(pointer), ref other => {