From b5c8d9bcf970939fb1c5e09696867d1b96b449b4 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 19 Oct 2023 16:52:23 -0700 Subject: [PATCH 1/4] [wgsl-in] Delete front::wgsl::parse::ExpressionContext::reborrow. Remove `ExpressionContext::reborrow` in favor of Rust's automatic reborrowing of `&mut` references. Use lifetime elision in more places. --- src/front/wgsl/parse/mod.rs | 285 +++++++++++++++++------------------- 1 file changed, 135 insertions(+), 150 deletions(-) diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index 23398e6116..431257aff8 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -55,31 +55,21 @@ struct ExpressionContext<'input, 'temp, 'out> { } impl<'a> ExpressionContext<'a, '_, '_> { - fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> { - ExpressionContext { - expressions: self.expressions, - types: self.types, - local_table: self.local_table, - locals: self.locals, - unresolved: self.unresolved, - } - } - fn parse_binary_op( &mut self, lexer: &mut Lexer<'a>, classifier: impl Fn(Token<'a>) -> Option, mut parser: impl FnMut( &mut Lexer<'a>, - ExpressionContext<'a, '_, '_>, + &mut Self, ) -> Result>, Error<'a>>, ) -> Result>, Error<'a>> { let start = lexer.start_byte_offset(); - let mut accumulator = parser(lexer, self.reborrow())?; + let mut accumulator = parser(lexer, self)?; while let Some(op) = classifier(lexer.peek().0) { let _ = lexer.next(); let left = accumulator; - let right = parser(lexer, self.reborrow())?; + let right = parser(lexer, self)?; accumulator = self.expressions.append( ast::Expression::Binary { op, left, right }, lexer.span_from(start), @@ -157,13 +147,13 @@ impl<'a> BindingParser<'a> { lexer: &mut Lexer<'a>, name: &'a str, name_span: Span, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result<(), Error<'a>> { match name { "location" => { lexer.expect(Token::Paren('('))?; self.location - .set(parser.general_expression(lexer, ctx.reborrow())?, name_span)?; + .set(parser.general_expression(lexer, ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } "builtin" => { @@ -258,14 +248,14 @@ impl Parser { fn switch_value<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result, Error<'a>> { if let Token::Word("default") = lexer.peek().0 { let _ = lexer.next(); return Ok(ast::SwitchValue::Default); } - let expr = self.general_expression(lexer, ctx.reborrow())?; + let expr = self.general_expression(lexer, ctx)?; Ok(ast::SwitchValue::Expr(expr)) } @@ -285,7 +275,7 @@ impl Parser { lexer: &mut Lexer<'a>, word: &'a str, span: Span, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { if let Some((kind, width)) = conv::get_scalar_type(word) { return Ok(Some(ast::ConstructorType::Scalar { kind, width })); @@ -509,9 +499,9 @@ impl Parser { } (Token::Paren('<'), ast::ConstructorType::PartialArray) => { lexer.expect_generic_paren('<')?; - let base = self.type_decl(lexer, ctx.reborrow())?; + let base = self.type_decl(lexer, ctx)?; let size = if lexer.skip(Token::Separator(',')) { - let expr = self.unary_expression(lexer, ctx.reborrow())?; + let expr = self.unary_expression(lexer, ctx)?; ast::ArraySize::Constant(expr) } else { ast::ArraySize::Dynamic @@ -528,7 +518,7 @@ impl Parser { fn arguments<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>>, Error<'a>> { lexer.open_arguments()?; let mut arguments = Vec::new(); @@ -540,7 +530,7 @@ impl Parser { } else if lexer.skip(Token::Paren(')')) { break; } - let arg = self.general_expression(lexer, ctx.reborrow())?; + let arg = self.general_expression(lexer, ctx)?; arguments.push(arg); } @@ -554,7 +544,7 @@ impl Parser { lexer: &mut Lexer<'a>, name: &'a str, name_span: Span, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { assert!(self.rules.last().is_some()); @@ -563,12 +553,12 @@ impl Parser { "bitcast" => { lexer.expect_generic_paren('<')?; let start = lexer.start_byte_offset(); - let to = self.type_decl(lexer, ctx.reborrow())?; + let to = self.type_decl(lexer, ctx)?; let span = lexer.span_from(start); lexer.expect_generic_paren('>')?; lexer.open_arguments()?; - let expr = self.general_expression(lexer, ctx.reborrow())?; + let expr = self.general_expression(lexer, ctx)?; lexer.close_arguments()?; ast::Expression::Bitcast { @@ -579,7 +569,7 @@ impl Parser { } // everything else must be handled later, since they can be hidden by user-defined functions. _ => { - let arguments = self.arguments(lexer, ctx.reborrow())?; + let arguments = self.arguments(lexer, ctx)?; ctx.unresolved.insert(ast::Dependency { ident: name, usage: name_span, @@ -603,7 +593,7 @@ impl Parser { &mut self, name: &'a str, name_span: Span, - ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> ast::IdentExpr<'a> { match ctx.local_table.lookup(name) { Some(&local) => ast::IdentExpr::Local(local), @@ -620,14 +610,14 @@ impl Parser { fn primary_expression<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { self.push_rule_span(Rule::PrimaryExpr, lexer); let expr = match lexer.peek() { (Token::Paren('('), _) => { let _ = lexer.next(); - let expr = self.general_expression(lexer, ctx.reborrow())?; + let expr = self.general_expression(lexer, ctx)?; lexer.expect(Token::Paren(')'))?; self.pop_rule_span(lexer); return Ok(expr); @@ -661,9 +651,9 @@ impl Parser { let start = lexer.start_byte_offset(); let _ = lexer.next(); - if let Some(ty) = self.constructor_type(lexer, word, span, ctx.reborrow())? { + if let Some(ty) = self.constructor_type(lexer, word, span, ctx)? { let ty_span = lexer.span_from(start); - let components = self.arguments(lexer, ctx.reborrow())?; + let components = self.arguments(lexer, ctx)?; ast::Expression::Construct { ty, ty_span, @@ -676,7 +666,7 @@ impl Parser { self.pop_rule_span(lexer); return self.function_call(lexer, word, span, ctx); } else { - let ident = self.ident_expr(word, span, ctx.reborrow()); + let ident = self.ident_expr(word, span, ctx); ast::Expression::Ident(ident) } } @@ -692,7 +682,7 @@ impl Parser { &mut self, span_start: usize, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, expr: Handle>, ) -> Result>, Error<'a>> { let mut expr = expr; @@ -707,7 +697,7 @@ impl Parser { } Token::Paren('[') => { let _ = lexer.next(); - let index = self.general_expression(lexer, ctx.reborrow())?; + let index = self.general_expression(lexer, ctx)?; lexer.expect(Token::Paren(']'))?; ast::Expression::Index { base: expr, index } @@ -726,14 +716,14 @@ impl Parser { fn unary_expression<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { self.push_rule_span(Rule::UnaryExpr, lexer); //TODO: refactor this to avoid backing up let expr = match lexer.peek().0 { Token::Operation('-') => { let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx.reborrow())?; + let expr = self.unary_expression(lexer, ctx)?; let expr = ast::Expression::Unary { op: crate::UnaryOperator::Negate, expr, @@ -743,7 +733,7 @@ impl Parser { } Token::Operation('!') => { let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx.reborrow())?; + let expr = self.unary_expression(lexer, ctx)?; let expr = ast::Expression::Unary { op: crate::UnaryOperator::LogicalNot, expr, @@ -753,7 +743,7 @@ impl Parser { } Token::Operation('~') => { let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx.reborrow())?; + let expr = self.unary_expression(lexer, ctx)?; let expr = ast::Expression::Unary { op: crate::UnaryOperator::BitwiseNot, expr, @@ -763,19 +753,19 @@ impl Parser { } Token::Operation('*') => { let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx.reborrow())?; + let expr = self.unary_expression(lexer, ctx)?; let expr = ast::Expression::Deref(expr); let span = self.peek_rule_span(lexer); ctx.expressions.append(expr, span) } Token::Operation('&') => { let _ = lexer.next(); - let expr = self.unary_expression(lexer, ctx.reborrow())?; + let expr = self.unary_expression(lexer, ctx)?; let expr = ast::Expression::AddrOf(expr); let span = self.peek_rule_span(lexer); ctx.expressions.append(expr, span) } - _ => self.singular_expression(lexer, ctx.reborrow())?, + _ => self.singular_expression(lexer, ctx)?, }; self.pop_rule_span(lexer); @@ -786,12 +776,12 @@ impl Parser { fn singular_expression<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { let start = lexer.start_byte_offset(); self.push_rule_span(Rule::SingularExpr, lexer); - let primary_expr = self.primary_expression(lexer, ctx.reborrow())?; - let singular_expr = self.postfix(start, lexer, ctx.reborrow(), primary_expr)?; + let primary_expr = self.primary_expression(lexer, ctx)?; + let singular_expr = self.postfix(start, lexer, ctx, primary_expr)?; self.pop_rule_span(lexer); Ok(singular_expr) @@ -800,7 +790,7 @@ impl Parser { fn equality_expression<'a>( &mut self, lexer: &mut Lexer<'a>, - mut context: ExpressionContext<'a, '_, '_>, + context: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { // equality_expression context.parse_binary_op( @@ -811,7 +801,7 @@ impl Parser { _ => None, }, // relational_expression - |lexer, mut context| { + |lexer, context| { context.parse_binary_op( lexer, |token| match token { @@ -822,7 +812,7 @@ impl Parser { _ => None, }, // shift_expression - |lexer, mut context| { + |lexer, context| { context.parse_binary_op( lexer, |token| match token { @@ -835,7 +825,7 @@ impl Parser { _ => None, }, // additive_expression - |lexer, mut context| { + |lexer, context| { context.parse_binary_op( lexer, |token| match token { @@ -846,7 +836,7 @@ impl Parser { _ => None, }, // multiplicative_expression - |lexer, mut context| { + |lexer, context| { context.parse_binary_op( lexer, |token| match token { @@ -876,16 +866,16 @@ impl Parser { fn general_expression<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { - self.general_expression_with_span(lexer, ctx.reborrow()) + self.general_expression_with_span(lexer, ctx) .map(|(expr, _)| expr) } fn general_expression_with_span<'a>( &mut self, lexer: &mut Lexer<'a>, - mut context: ExpressionContext<'a, '_, '_>, + context: &mut ExpressionContext<'a, '_, '_>, ) -> Result<(Handle>, Span), Error<'a>> { self.push_rule_span(Rule::GeneralExpr, lexer); // logical_or_expression @@ -896,7 +886,7 @@ impl Parser { _ => None, }, // logical_and_expression - |lexer, mut context| { + |lexer, context| { context.parse_binary_op( lexer, |token| match token { @@ -904,7 +894,7 @@ impl Parser { _ => None, }, // inclusive_or_expression - |lexer, mut context| { + |lexer, context| { context.parse_binary_op( lexer, |token| match token { @@ -912,7 +902,7 @@ impl Parser { _ => None, }, // exclusive_or_expression - |lexer, mut context| { + |lexer, context| { context.parse_binary_op( lexer, |token| match token { @@ -922,7 +912,7 @@ impl Parser { _ => None, }, // and_expression - |lexer, mut context| { + |lexer, context| { context.parse_binary_op( lexer, |token| match token { @@ -949,7 +939,7 @@ impl Parser { fn variable_decl<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result, Error<'a>> { self.push_rule_span(Rule::VariableDecl, lexer); let mut space = crate::AddressSpace::Handle; @@ -972,10 +962,10 @@ impl Parser { } let name = lexer.next_ident()?; lexer.expect(Token::Separator(':'))?; - let ty = self.type_decl(lexer, ctx.reborrow())?; + let ty = self.type_decl(lexer, ctx)?; let init = if lexer.skip(Token::Operation('=')) { - let handle = self.general_expression(lexer, ctx.reborrow())?; + let handle = self.general_expression(lexer, ctx)?; Some(handle) } else { None @@ -995,7 +985,7 @@ impl Parser { fn struct_body<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { let mut members = Vec::new(); @@ -1015,19 +1005,17 @@ impl Parser { match lexer.next_ident_with_span()? { ("size", name_span) => { lexer.expect(Token::Paren('('))?; - let expr = self.general_expression(lexer, ctx.reborrow())?; + let expr = self.general_expression(lexer, ctx)?; lexer.expect(Token::Paren(')'))?; size.set(expr, name_span)?; } ("align", name_span) => { lexer.expect(Token::Paren('('))?; - let expr = self.general_expression(lexer, ctx.reborrow())?; + let expr = self.general_expression(lexer, ctx)?; lexer.expect(Token::Paren(')'))?; align.set(expr, name_span)?; } - (word, word_span) => { - bind_parser.parse(self, lexer, word, word_span, ctx.reborrow())? - } + (word, word_span) => bind_parser.parse(self, lexer, word, word_span, ctx)?, } } @@ -1036,7 +1024,7 @@ impl Parser { let name = lexer.next_ident()?; lexer.expect(Token::Separator(':'))?; - let ty = self.type_decl(lexer, ctx.reborrow())?; + let ty = self.type_decl(lexer, ctx)?; ready = lexer.skip(Token::Separator(',')); members.push(ast::StructMember { @@ -1072,7 +1060,7 @@ impl Parser { &mut self, lexer: &mut Lexer<'a>, word: &'a str, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { if let Some((kind, width)) = conv::get_scalar_type(word) { return Ok(Some(ast::Type::Scalar { kind, width })); @@ -1242,9 +1230,9 @@ impl Parser { } "array" => { lexer.expect_generic_paren('<')?; - let base = self.type_decl(lexer, ctx.reborrow())?; + let base = self.type_decl(lexer, ctx)?; let size = if lexer.skip(Token::Separator(',')) { - let size = self.unary_expression(lexer, ctx.reborrow())?; + let size = self.unary_expression(lexer, ctx)?; ast::ArraySize::Constant(size) } else { ast::ArraySize::Dynamic @@ -1255,9 +1243,9 @@ impl Parser { } "binding_array" => { lexer.expect_generic_paren('<')?; - let base = self.type_decl(lexer, ctx.reborrow())?; + let base = self.type_decl(lexer, ctx)?; let size = if lexer.skip(Token::Separator(',')) { - let size = self.unary_expression(lexer, ctx.reborrow())?; + let size = self.unary_expression(lexer, ctx)?; ast::ArraySize::Constant(size) } else { ast::ArraySize::Dynamic @@ -1439,13 +1427,13 @@ impl Parser { fn type_decl<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { self.push_rule_span(Rule::TypeDecl, lexer); let (name, span) = lexer.next_ident_with_span()?; - let ty = match self.type_decl_impl(lexer, name, ctx.reborrow())? { + let ty = match self.type_decl_impl(lexer, name, ctx)? { Some(ty) => ty, None => { ctx.unresolved.insert(ast::Dependency { @@ -1462,11 +1450,11 @@ impl Parser { Ok(handle) } - fn assignment_op_and_rhs<'a, 'out>( + fn assignment_op_and_rhs<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, 'out>, - block: &'out mut ast::Block<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, target: Handle>, span_start: usize, ) -> Result<(), Error<'a>> { @@ -1475,7 +1463,7 @@ impl Parser { let op = lexer.next(); let (op, value) = match op { (Token::Operation('='), _) => { - let value = self.general_expression(lexer, ctx.reborrow())?; + let value = self.general_expression(lexer, ctx)?; (None, value) } (Token::AssignmentOperation(c), _) => { @@ -1494,7 +1482,7 @@ impl Parser { _ => unreachable!(), }; - let value = self.general_expression(lexer, ctx.reborrow())?; + let value = self.general_expression(lexer, ctx)?; (Some(op), value) } token @ (Token::IncrementOperation | Token::DecrementOperation, _) => { @@ -1523,27 +1511,27 @@ impl Parser { } /// Parse an assignment statement (will also parse increment and decrement statements) - fn assignment_statement<'a, 'out>( + fn assignment_statement<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, 'out>, - block: &'out mut ast::Block<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, ) -> Result<(), Error<'a>> { let span_start = lexer.start_byte_offset(); - let target = self.general_expression(lexer, ctx.reborrow())?; + let target = self.general_expression(lexer, ctx)?; self.assignment_op_and_rhs(lexer, ctx, block, target, span_start) } /// Parse a function call statement. /// Expects `ident` to be consumed (not in the lexer). - fn function_statement<'a, 'out>( + fn function_statement<'a>( &mut self, lexer: &mut Lexer<'a>, ident: &'a str, ident_span: Span, span_start: usize, - mut context: ExpressionContext<'a, '_, 'out>, - block: &'out mut ast::Block<'a>, + context: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, ) -> Result<(), Error<'a>> { self.push_rule_span(Rule::SingularExpr, lexer); @@ -1551,7 +1539,7 @@ impl Parser { ident, usage: ident_span, }); - let arguments = self.arguments(lexer, context.reborrow())?; + let arguments = self.arguments(lexer, context)?; let span = lexer.span_from(span_start); block.stmts.push(ast::Statement { @@ -1570,11 +1558,11 @@ impl Parser { Ok(()) } - fn function_call_or_assignment_statement<'a, 'out>( + fn function_call_or_assignment_statement<'a>( &mut self, lexer: &mut Lexer<'a>, - mut context: ExpressionContext<'a, '_, 'out>, - block: &'out mut ast::Block<'a>, + context: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, ) -> Result<(), Error<'a>> { let span_start = lexer.start_byte_offset(); match lexer.peek() { @@ -1583,29 +1571,24 @@ impl Parser { let cloned = lexer.clone(); let _ = lexer.next(); match lexer.peek() { - (Token::Paren('('), _) => self.function_statement( - lexer, - name, - span, - span_start, - context.reborrow(), - block, - ), + (Token::Paren('('), _) => { + self.function_statement(lexer, name, span, span_start, context, block) + } _ => { *lexer = cloned; - self.assignment_statement(lexer, context.reborrow(), block) + self.assignment_statement(lexer, context, block) } } } - _ => self.assignment_statement(lexer, context.reborrow(), block), + _ => self.assignment_statement(lexer, context, block), } } - fn statement<'a, 'out>( + fn statement<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, 'out>, - block: &'out mut ast::Block<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, ) -> Result<(), Error<'a>> { self.push_rule_span(Rule::Statement, lexer); match lexer.peek() { @@ -1615,7 +1598,7 @@ impl Parser { return Ok(()); } (Token::Paren('{'), _) => { - let (inner, span) = self.block(lexer, ctx.reborrow())?; + let (inner, span) = self.block(lexer, ctx)?; block.stmts.push(ast::Statement { kind: ast::StatementKind::Block(inner), span, @@ -1628,7 +1611,7 @@ impl Parser { "_" => { let _ = lexer.next(); lexer.expect(Token::Operation('='))?; - let expr = self.general_expression(lexer, ctx.reborrow())?; + let expr = self.general_expression(lexer, ctx)?; lexer.expect(Token::Separator(';'))?; ast::StatementKind::Ignore(expr) @@ -1638,13 +1621,13 @@ impl Parser { let name = lexer.next_ident()?; let given_ty = if lexer.skip(Token::Separator(':')) { - let ty = self.type_decl(lexer, ctx.reborrow())?; + let ty = self.type_decl(lexer, ctx)?; Some(ty) } else { None }; lexer.expect(Token::Operation('='))?; - let expr_id = self.general_expression(lexer, ctx.reborrow())?; + let expr_id = self.general_expression(lexer, ctx)?; lexer.expect(Token::Separator(';'))?; let handle = ctx.declare_local(name)?; @@ -1660,14 +1643,14 @@ impl Parser { let name = lexer.next_ident()?; let ty = if lexer.skip(Token::Separator(':')) { - let ty = self.type_decl(lexer, ctx.reborrow())?; + let ty = self.type_decl(lexer, ctx)?; Some(ty) } else { None }; let init = if lexer.skip(Token::Operation('=')) { - let init = self.general_expression(lexer, ctx.reborrow())?; + let init = self.general_expression(lexer, ctx)?; Some(init) } else { None @@ -1686,7 +1669,7 @@ impl Parser { "return" => { let _ = lexer.next(); let value = if lexer.peek().0 != Token::Separator(';') { - let handle = self.general_expression(lexer, ctx.reborrow())?; + let handle = self.general_expression(lexer, ctx)?; Some(handle) } else { None @@ -1696,9 +1679,9 @@ impl Parser { } "if" => { let _ = lexer.next(); - let condition = self.general_expression(lexer, ctx.reborrow())?; + let condition = self.general_expression(lexer, ctx)?; - let accept = self.block(lexer, ctx.reborrow())?.0; + let accept = self.block(lexer, ctx)?.0; let mut elsif_stack = Vec::new(); let mut elseif_span_start = lexer.start_byte_offset(); @@ -1709,12 +1692,12 @@ impl Parser { if !lexer.skip(Token::Word("if")) { // ... else { ... } - break self.block(lexer, ctx.reborrow())?.0; + break self.block(lexer, ctx)?.0; } // ... else if (...) { ... } - let other_condition = self.general_expression(lexer, ctx.reborrow())?; - let other_block = self.block(lexer, ctx.reborrow())?; + let other_condition = self.general_expression(lexer, ctx)?; + let other_block = self.block(lexer, ctx)?; elsif_stack.push((elseif_span_start, other_condition, other_block)); elseif_span_start = lexer.start_byte_offset(); }; @@ -1745,7 +1728,7 @@ impl Parser { } "switch" => { let _ = lexer.next(); - let selector = self.general_expression(lexer, ctx.reborrow())?; + let selector = self.general_expression(lexer, ctx)?; lexer.expect(Token::Paren('{'))?; let mut cases = Vec::new(); @@ -1755,7 +1738,7 @@ impl Parser { (Token::Word("case"), _) => { // parse a list of values let value = loop { - let value = self.switch_value(lexer, ctx.reborrow())?; + let value = self.switch_value(lexer, ctx)?; if lexer.skip(Token::Separator(',')) { if lexer.skip(Token::Separator(':')) { break value; @@ -1771,7 +1754,7 @@ impl Parser { }); }; - let body = self.block(lexer, ctx.reborrow())?.0; + let body = self.block(lexer, ctx)?.0; cases.push(ast::SwitchCase { value, @@ -1781,7 +1764,7 @@ impl Parser { } (Token::Word("default"), _) => { lexer.skip(Token::Separator(':')); - let body = self.block(lexer, ctx.reborrow())?.0; + let body = self.block(lexer, ctx)?.0; cases.push(ast::SwitchCase { value: ast::SwitchValue::Default, body, @@ -1797,13 +1780,13 @@ impl Parser { ast::StatementKind::Switch { selector, cases } } - "loop" => self.r#loop(lexer, ctx.reborrow())?, + "loop" => self.r#loop(lexer, ctx)?, "while" => { let _ = lexer.next(); let mut body = ast::Block::default(); let (condition, span) = lexer.capture_span(|lexer| { - let condition = self.general_expression(lexer, ctx.reborrow())?; + let condition = self.general_expression(lexer, ctx)?; Ok(condition) })?; let mut reject = ast::Block::default(); @@ -1821,7 +1804,7 @@ impl Parser { span, }); - let (block, span) = self.block(lexer, ctx.reborrow())?; + let (block, span) = self.block(lexer, ctx)?; body.stmts.push(ast::Statement { kind: ast::StatementKind::Block(block), span, @@ -1841,9 +1824,11 @@ impl Parser { if !lexer.skip(Token::Separator(';')) { let num_statements = block.stmts.len(); - let (_, span) = lexer.capture_span(|lexer| { - self.statement(lexer, ctx.reborrow(), block) - })?; + let (_, span) = { + let ctx = &mut *ctx; + let block = &mut *block; + lexer.capture_span(|lexer| self.statement(lexer, ctx, block))? + }; if block.stmts.len() != num_statements { match block.stmts.last().unwrap().kind { @@ -1858,7 +1843,7 @@ impl Parser { let mut body = ast::Block::default(); if !lexer.skip(Token::Separator(';')) { let (condition, span) = lexer.capture_span(|lexer| { - let condition = self.general_expression(lexer, ctx.reborrow())?; + let condition = self.general_expression(lexer, ctx)?; lexer.expect(Token::Separator(';'))?; Ok(condition) })?; @@ -1881,13 +1866,13 @@ impl Parser { if !lexer.skip(Token::Paren(')')) { self.function_call_or_assignment_statement( lexer, - ctx.reborrow(), + ctx, &mut continuing, )?; lexer.expect(Token::Paren(')'))?; } - let (block, span) = self.block(lexer, ctx.reborrow())?; + let (block, span) = self.block(lexer, ctx)?; body.stmts.push(ast::Statement { kind: ast::StatementKind::Block(block), span, @@ -1926,7 +1911,7 @@ impl Parser { } // assignment or a function call _ => { - self.function_call_or_assignment_statement(lexer, ctx.reborrow(), block)?; + self.function_call_or_assignment_statement(lexer, ctx, block)?; lexer.expect(Token::Separator(';'))?; self.pop_rule_span(lexer); return Ok(()); @@ -1937,7 +1922,7 @@ impl Parser { block.stmts.push(ast::Statement { kind, span }); } _ => { - self.assignment_statement(lexer, ctx.reborrow(), block)?; + self.assignment_statement(lexer, ctx, block)?; lexer.expect(Token::Separator(';'))?; self.pop_rule_span(lexer); } @@ -1948,7 +1933,7 @@ impl Parser { fn r#loop<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result, Error<'a>> { let _ = lexer.next(); let mut body = ast::Block::default(); @@ -1976,7 +1961,7 @@ impl Parser { // the break if lexer.expect(Token::Word("if"))?; - let condition = self.general_expression(lexer, ctx.reborrow())?; + let condition = self.general_expression(lexer, ctx)?; // Set the condition of the break if to the newly parsed // expression break_if = Some(condition); @@ -1994,7 +1979,7 @@ impl Parser { break; } else { // Otherwise try to parse a statement - self.statement(lexer, ctx.reborrow(), &mut continuing)?; + self.statement(lexer, ctx, &mut continuing)?; } } // Since the continuing block must be the last part of the loop body, @@ -2008,7 +1993,7 @@ impl Parser { break; } // Otherwise try to parse a statement - self.statement(lexer, ctx.reborrow(), &mut body)?; + self.statement(lexer, ctx, &mut body)?; } ctx.local_table.pop_scope(); @@ -2024,7 +2009,7 @@ impl Parser { fn block<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result<(ast::Block<'a>, Span), Error<'a>> { self.push_rule_span(Rule::Block, lexer); @@ -2033,7 +2018,7 @@ impl Parser { lexer.expect(Token::Paren('{'))?; let mut block = ast::Block::default(); while !lexer.skip(Token::Paren('}')) { - self.statement(lexer, ctx.reborrow(), &mut block)?; + self.statement(lexer, ctx, &mut block)?; } ctx.local_table.pop_scope(); @@ -2045,14 +2030,14 @@ impl Parser { fn varying_binding<'a>( &mut self, lexer: &mut Lexer<'a>, - mut ctx: ExpressionContext<'a, '_, '_>, + ctx: &mut ExpressionContext<'a, '_, '_>, ) -> Result>, Error<'a>> { let mut bind_parser = BindingParser::default(); self.push_rule_span(Rule::Attribute, lexer); while lexer.skip(Token::Attribute) { let (word, span) = lexer.next_ident_with_span()?; - bind_parser.parse(self, lexer, word, span, ctx.reborrow())?; + bind_parser.parse(self, lexer, word, span, ctx)?; } let span = self.pop_rule_span(lexer); @@ -2093,12 +2078,12 @@ impl Parser { ExpectedToken::Token(Token::Separator(',')), )); } - let binding = self.varying_binding(lexer, ctx.reborrow())?; + let binding = self.varying_binding(lexer, &mut ctx)?; let param_name = lexer.next_ident()?; lexer.expect(Token::Separator(':'))?; - let param_type = self.type_decl(lexer, ctx.reborrow())?; + let param_type = self.type_decl(lexer, &mut ctx)?; let handle = ctx.declare_local(param_name)?; arguments.push(ast::FunctionArgument { @@ -2111,8 +2096,8 @@ impl Parser { } // read return type let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) { - let binding = self.varying_binding(lexer, ctx.reborrow())?; - let ty = self.type_decl(lexer, ctx.reborrow())?; + let binding = self.varying_binding(lexer, &mut ctx)?; + let ty = self.type_decl(lexer, &mut ctx)?; Some(ast::FunctionResult { ty, binding }) } else { None @@ -2122,7 +2107,7 @@ impl Parser { lexer.expect(Token::Paren('{'))?; let mut body = ast::Block::default(); while !lexer.skip(Token::Paren('}')) { - self.statement(lexer, ctx.reborrow(), &mut body)?; + self.statement(lexer, &mut ctx, &mut body)?; } ctx.local_table.pop_scope(); @@ -2170,12 +2155,12 @@ impl Parser { match lexer.next_ident_with_span()? { ("binding", name_span) => { lexer.expect(Token::Paren('('))?; - bind_index.set(self.general_expression(lexer, ctx.reborrow())?, name_span)?; + bind_index.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } ("group", name_span) => { lexer.expect(Token::Paren('('))?; - bind_group.set(self.general_expression(lexer, ctx.reborrow())?, name_span)?; + bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } ("vertex", name_span) => { @@ -2192,7 +2177,7 @@ impl Parser { lexer.expect(Token::Paren('('))?; let mut new_workgroup_size = [None; 3]; for (i, size) in new_workgroup_size.iter_mut().enumerate() { - *size = Some(self.general_expression(lexer, ctx.reborrow())?); + *size = Some(self.general_expression(lexer, &mut ctx)?); match lexer.next() { (Token::Paren(')'), _) => break, (Token::Separator(','), _) if i != 2 => (), @@ -2241,14 +2226,14 @@ impl Parser { (Token::Word("struct"), _) => { let name = lexer.next_ident()?; - let members = self.struct_body(lexer, ctx)?; + let members = self.struct_body(lexer, &mut ctx)?; Some(ast::GlobalDeclKind::Struct(ast::Struct { name, members })) } (Token::Word("alias"), _) => { let name = lexer.next_ident()?; lexer.expect(Token::Operation('='))?; - let ty = self.type_decl(lexer, ctx)?; + let ty = self.type_decl(lexer, &mut ctx)?; lexer.expect(Token::Separator(';'))?; Some(ast::GlobalDeclKind::Type(ast::TypeAlias { name, ty })) } @@ -2256,20 +2241,20 @@ impl Parser { let name = lexer.next_ident()?; let ty = if lexer.skip(Token::Separator(':')) { - let ty = self.type_decl(lexer, ctx.reborrow())?; + let ty = self.type_decl(lexer, &mut ctx)?; Some(ty) } else { None }; lexer.expect(Token::Operation('='))?; - let init = self.general_expression(lexer, ctx)?; + let init = self.general_expression(lexer, &mut ctx)?; lexer.expect(Token::Separator(';'))?; Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) } (Token::Word("var"), _) => { - let mut var = self.variable_decl(lexer, ctx)?; + let mut var = self.variable_decl(lexer, &mut ctx)?; var.binding = binding.take(); Some(ast::GlobalDeclKind::Var(var)) } From 1ecb111842c08c4912a521a15fc2a670a3e9022f Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 19 Oct 2023 17:07:41 -0700 Subject: [PATCH 2/4] [wgsl-in] Delete front::wgsl::lower::GlobalContext::reborrow. Remove `GlobalContext::reborrow` in favor of Rust's automatic reborrowing of `&mut` references. --- src/front/wgsl/lower/construction.rs | 4 +- src/front/wgsl/lower/mod.rs | 58 ++++++++++++---------------- 2 files changed, 26 insertions(+), 36 deletions(-) diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index ec3a338706..48d1f9bbfc 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -579,8 +579,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::ConstructorType::PartialArray => ConcreteConstructorHandle::PartialArray, ast::ConstructorType::Array { base, size } => { - let base = self.resolve_ast_type(base, ctx.as_global())?; - let size = self.array_size(size, ctx.as_global())?; + let base = self.resolve_ast_type(base, &mut ctx.as_global())?; + let size = self.array_size(size, &mut ctx.as_global())?; self.layouter.update(ctx.module.to_ctx()).unwrap(); let stride = self.layouter[base].to_stride(); diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index ae178cb702..f6d0cc1856 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -86,16 +86,6 @@ pub struct GlobalContext<'source, 'temp, 'out> { } impl<'source> GlobalContext<'source, '_, '_> { - fn reborrow(&mut self) -> GlobalContext<'source, '_, '_> { - GlobalContext { - ast_expressions: self.ast_expressions, - globals: self.globals, - types: self.types, - module: self.module, - const_typifier: self.const_typifier, - } - } - fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { ExpressionContext { ast_expressions: self.ast_expressions, @@ -917,11 +907,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { match decl.kind { ast::GlobalDeclKind::Fn(ref f) => { - let lowered_decl = self.function(f, span, ctx.reborrow())?; + let lowered_decl = self.function(f, span, &mut ctx)?; ctx.globals.insert(f.name.name, lowered_decl); } ast::GlobalDeclKind::Var(ref v) => { - let ty = self.resolve_ast_type(v.ty, ctx.reborrow())?; + let ty = self.resolve_ast_type(v.ty, &mut ctx)?; let init = v .init @@ -957,7 +947,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let inferred_type = ectx.register_type(init)?; let explicit_ty = - c.ty.map(|ty| self.resolve_ast_type(ty, ctx.reborrow())) + c.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) .transpose()?; if let Some(explicit) = explicit_ty { @@ -996,12 +986,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .insert(c.name.name, LoweredGlobalDecl::Const(handle)); } ast::GlobalDeclKind::Struct(ref s) => { - let handle = self.r#struct(s, span, ctx.reborrow())?; + let handle = self.r#struct(s, span, &mut ctx)?; ctx.globals .insert(s.name.name, LoweredGlobalDecl::Type(handle)); } ast::GlobalDeclKind::Type(ref alias) => { - let ty = self.resolve_ast_type(alias.ty, ctx.reborrow())?; + let ty = self.resolve_ast_type(alias.ty, &mut ctx)?; ctx.globals .insert(alias.name.name, LoweredGlobalDecl::Type(ty)); } @@ -1015,7 +1005,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, f: &ast::Function<'source>, span: Span, - mut ctx: GlobalContext<'source, '_, '_>, + ctx: &mut GlobalContext<'source, '_, '_>, ) -> Result> { let mut local_table = FastHashMap::default(); let mut local_variables = Arena::new(); @@ -1027,7 +1017,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .iter() .enumerate() .map(|(i, arg)| { - let ty = self.resolve_ast_type(arg.ty, ctx.reborrow())?; + let ty = self.resolve_ast_type(arg.ty, ctx)?; let expr = expressions .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); local_table.insert(arg.handle, TypedExpression::non_reference(expr)); @@ -1036,7 +1026,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(crate::FunctionArgument { name: Some(arg.name.name.to_string()), ty, - binding: self.binding(&arg.binding, ty, ctx.reborrow())?, + binding: self.binding(&arg.binding, ty, ctx)?, }) }) .collect::, _>>()?; @@ -1045,10 +1035,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .result .as_ref() .map(|res| { - let ty = self.resolve_ast_type(res.ty, ctx.reborrow())?; + let ty = self.resolve_ast_type(res.ty, ctx)?; Ok(crate::FunctionResult { ty, - binding: self.binding(&res.binding, ty, ctx.reborrow())?, + binding: self.binding(&res.binding, ty, ctx)?, }) }) .transpose()?; @@ -1157,7 +1147,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx.expression_constness.force_non_const(value); let explicit_ty = - l.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_global())) + l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) .transpose()?; if let Some(ty) = explicit_ty { @@ -1195,7 +1185,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; let explicit_ty = - v.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_global())) + v.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) .transpose()?; let ty = match (explicit_ty, initializer) { @@ -1735,7 +1725,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::Expression::Bitcast { expr, to, ty_span } => { let expr = self.expression(expr, ctx.reborrow())?; - let to_resolved = self.resolve_ast_type(to, ctx.as_global())?; + let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?; let kind = match ctx.module.types[to_resolved].inner { crate::TypeInner::Scalar { kind, .. } => kind, @@ -2486,14 +2476,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, s: &ast::Struct<'source>, span: Span, - mut ctx: GlobalContext<'source, '_, '_>, + ctx: &mut GlobalContext<'source, '_, '_>, ) -> Result, Error<'source>> { let mut offset = 0; let mut struct_alignment = Alignment::ONE; let mut members = Vec::with_capacity(s.members.len()); for member in s.members.iter() { - let ty = self.resolve_ast_type(member.ty, ctx.reborrow())?; + let ty = self.resolve_ast_type(member.ty, ctx)?; self.layouter.update(ctx.module.to_ctx()).unwrap(); @@ -2526,7 +2516,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { member_min_alignment }; - let binding = self.binding(&member.binding, ty, ctx.reborrow())?; + let binding = self.binding(&member.binding, ty, ctx)?; offset = member_alignment.round_up(offset); struct_alignment = struct_alignment.max(member_alignment); @@ -2580,7 +2570,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn array_size( &mut self, size: ast::ArraySize<'source>, - mut ctx: GlobalContext<'source, '_, '_>, + ctx: &mut GlobalContext<'source, '_, '_>, ) -> Result> { Ok(match size { ast::ArraySize::Constant(expr) => { @@ -2609,7 +2599,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn resolve_ast_type( &mut self, handle: Handle>, - mut ctx: GlobalContext<'source, '_, '_>, + ctx: &mut GlobalContext<'source, '_, '_>, ) -> Result, Error<'source>> { let inner = match ctx.types[handle] { ast::Type::Scalar { kind, width } => crate::TypeInner::Scalar { kind, width }, @@ -2627,12 +2617,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, ast::Type::Atomic { kind, width } => crate::TypeInner::Atomic { kind, width }, ast::Type::Pointer { base, space } => { - let base = self.resolve_ast_type(base, ctx.reborrow())?; + let base = self.resolve_ast_type(base, ctx)?; crate::TypeInner::Pointer { base, space } } ast::Type::Array { base, size } => { - let base = self.resolve_ast_type(base, ctx.reborrow())?; - let size = self.array_size(size, ctx.reborrow())?; + let base = self.resolve_ast_type(base, ctx)?; + let size = self.array_size(size, ctx)?; self.layouter.update(ctx.module.to_ctx()).unwrap(); let stride = self.layouter[base].to_stride(); @@ -2652,8 +2642,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure, ast::Type::RayQuery => crate::TypeInner::RayQuery, ast::Type::BindingArray { base, size } => { - let base = self.resolve_ast_type(base, ctx.reborrow())?; - let size = self.array_size(size, ctx.reborrow())?; + let base = self.resolve_ast_type(base, ctx)?; + let size = self.array_size(size, ctx)?; crate::TypeInner::BindingArray { base, size } } ast::Type::RayDesc => { @@ -2678,7 +2668,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, binding: &Option>, ty: Handle, - mut ctx: GlobalContext<'source, '_, '_>, + ctx: &mut GlobalContext<'source, '_, '_>, ) -> Result, Error<'source>> { Ok(match *binding { Some(ast::Binding::BuiltIn(b)) => Some(crate::Binding::BuiltIn(b)), From 38886a1399c4f8400c03d6905a0a5fe6d99e546b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 19 Oct 2023 17:14:08 -0700 Subject: [PATCH 3/4] [wgsl-in] Delete front::wgsl::lower::StatementContext::reborrow. Remove `StatementContext::reborrow` in favor of Rust's automatic reborrowing of `&mut` references. --- src/front/wgsl/lower/mod.rs | 68 +++++++++++++------------------------ 1 file changed, 24 insertions(+), 44 deletions(-) diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index f6d0cc1856..3dfc05c9ef 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -163,23 +163,6 @@ pub struct StatementContext<'source, 'temp, 'out> { } impl<'a, 'temp> StatementContext<'a, 'temp, '_> { - fn reborrow(&mut self) -> StatementContext<'a, '_, '_> { - StatementContext { - local_table: self.local_table, - globals: self.globals, - types: self.types, - ast_expressions: self.ast_expressions, - const_typifier: self.const_typifier, - typifier: self.typifier, - variables: self.variables, - naga_expressions: self.naga_expressions, - named_expressions: self.named_expressions, - arguments: self.arguments, - module: self.module, - expression_constness: self.expression_constness, - } - } - fn as_expression<'t>( &'t mut self, block: &'t mut crate::Block, @@ -1044,24 +1027,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .transpose()?; let mut typifier = Typifier::default(); - let mut body = self.block( - &f.body, - false, - StatementContext { - local_table: &mut local_table, - globals: ctx.globals, - ast_expressions: ctx.ast_expressions, - const_typifier: ctx.const_typifier, - typifier: &mut typifier, - variables: &mut local_variables, - naga_expressions: &mut expressions, - named_expressions: &mut named_expressions, - types: ctx.types, - module: ctx.module, - arguments: &arguments, - expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), - }, - )?; + let mut stmt_ctx = StatementContext { + local_table: &mut local_table, + globals: ctx.globals, + ast_expressions: ctx.ast_expressions, + const_typifier: ctx.const_typifier, + typifier: &mut typifier, + variables: &mut local_variables, + naga_expressions: &mut expressions, + named_expressions: &mut named_expressions, + types: ctx.types, + module: ctx.module, + arguments: &arguments, + expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), + }; + let mut body = self.block(&f.body, false, &mut stmt_ctx)?; ensure_block_returns(&mut body); let function = crate::Function { @@ -1109,12 +1089,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &mut self, b: &ast::Block<'source>, is_inside_loop: bool, - mut ctx: StatementContext<'source, '_, '_>, + ctx: &mut StatementContext<'source, '_, '_>, ) -> Result> { let mut block = crate::Block::default(); for stmt in b.stmts.iter() { - self.statement(stmt, &mut block, is_inside_loop, ctx.reborrow())?; + self.statement(stmt, &mut block, is_inside_loop, ctx)?; } Ok(block) @@ -1125,11 +1105,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { stmt: &ast::Statement<'source>, block: &mut crate::Block, is_inside_loop: bool, - mut ctx: StatementContext<'source, '_, '_>, + ctx: &mut StatementContext<'source, '_, '_>, ) -> Result<(), Error<'source>> { let out = match stmt.kind { ast::StatementKind::Block(ref block) => { - let block = self.block(block, is_inside_loop, ctx.reborrow())?; + let block = self.block(block, is_inside_loop, ctx)?; crate::Statement::Block(block) } ast::StatementKind::LocalDecl(ref decl) => match *decl { @@ -1276,8 +1256,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { self.expression(condition, ctx.as_expression(block, &mut emitter))?; block.extend(emitter.finish(ctx.naga_expressions)); - let accept = self.block(accept, is_inside_loop, ctx.reborrow())?; - let reject = self.block(reject, is_inside_loop, ctx.reborrow())?; + let accept = self.block(accept, is_inside_loop, ctx)?; + let reject = self.block(reject, is_inside_loop, ctx)?; crate::Statement::If { condition, @@ -1321,7 +1301,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::SwitchValue::Default => crate::SwitchValue::Default, }, - body: self.block(&case.body, is_inside_loop, ctx.reborrow())?, + body: self.block(&case.body, is_inside_loop, ctx)?, fall_through: case.fall_through, }) }) @@ -1334,8 +1314,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ref continuing, break_if, } => { - let body = self.block(body, true, ctx.reborrow())?; - let mut continuing = self.block(continuing, true, ctx.reborrow())?; + let body = self.block(body, true, ctx)?; + let mut continuing = self.block(continuing, true, ctx)?; let mut emitter = Emitter::default(); emitter.start(ctx.naga_expressions); From 3ac47f0065416d8bf0b669f0e8853244df0fdef8 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 19 Oct 2023 17:31:58 -0700 Subject: [PATCH 4/4] [wgsl-in] Delete {ExpressionContext,RuntimeContext}::reborrow. Remove `front::wgsl::lower::ExpressionContext::reborrow` and `front::wgsl::lower::RuntimeExpressionContext::reborrow` in favor of Rust's automatic reborrowing of `&mut` references. --- src/front/wgsl/lower/construction.rs | 16 +- src/front/wgsl/lower/mod.rs | 273 ++++++++++++--------------- 2 files changed, 129 insertions(+), 160 deletions(-) diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index 48d1f9bbfc..1c46ecfd31 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -45,7 +45,7 @@ enum ConcreteConstructor<'a> { } impl ConcreteConstructorHandle { - fn to_error_string(&self, ctx: ExpressionContext) -> String { + fn to_error_string(&self, ctx: &mut ExpressionContext) -> String { match *self { Self::PartialVector { size } => { format!("vec{}", size as u32,) @@ -143,15 +143,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { constructor: &ast::ConstructorType<'source>, ty_span: Span, components: &[Handle>], - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { - let constructor_h = self.constructor(constructor, ctx.reborrow())?; + let constructor_h = self.constructor(constructor, ctx)?; let components_h = match *components { [] => ComponentsHandle::None, [component] => { let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx.reborrow())?; + let component = self.expression(component, ctx)?; let ty = super::resolve!(ctx, component); ComponentsHandle::One { @@ -162,12 +162,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } [component, ref rest @ ..] => { let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx.reborrow())?; + let component = self.expression(component, ctx)?; let components = std::iter::once(Ok(component)) .chain( rest.iter() - .map(|&component| self.expression(component, ctx.reborrow())), + .map(|&component| self.expression(component, ctx)), ) .collect::>()?; let spans = std::iter::once(span) @@ -491,7 +491,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Err(Error::BadTypeCast { span, from_type, - to_type: constructor_h.to_error_string(ctx.reborrow()), + to_type: constructor_h.to_error_string(ctx), }); } @@ -548,7 +548,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn constructor<'out>( &mut self, constructor: &ast::ConstructorType<'source>, - mut ctx: ExpressionContext<'source, '_, 'out>, + ctx: &mut ExpressionContext<'source, '_, 'out>, ) -> Result> { let c = match *constructor { ast::ConstructorType::Scalar { width, kind } => { diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 3dfc05c9ef..9176acca3f 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -235,21 +235,6 @@ pub struct RuntimeExpressionContext<'temp, 'out> { expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, } -impl RuntimeExpressionContext<'_, '_> { - fn reborrow(&mut self) -> RuntimeExpressionContext<'_, '_> { - RuntimeExpressionContext { - local_table: self.local_table, - naga_expressions: self.naga_expressions, - local_vars: self.local_vars, - arguments: self.arguments, - block: self.block, - emitter: self.emitter, - typifier: self.typifier, - expression_constness: self.expression_constness, - } - } -} - /// The type of Naga IR expression we are lowering an [`ast::Expression`] to. pub enum ExpressionContextType<'temp, 'out> { /// We are lowering to an arbitrary runtime expression, to be @@ -295,9 +280,6 @@ pub enum ExpressionContextType<'temp, 'out> { /// expressions, via [`Expression::Constant`], but constant /// expressions can't refer to a function's expressions. /// -/// - You can always call [`ExpressionContext::reborrow`] to get a fresh context -/// for a recursive call. The reborrowed context is equivalent to the original. -/// /// Not to be confused with `wgsl::parse::ExpressionContext`, which is /// for parsing the `ast::Expression` in the first place. /// @@ -335,22 +317,6 @@ pub struct ExpressionContext<'source, 'temp, 'out> { } impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { - fn reborrow(&mut self) -> ExpressionContext<'source, '_, '_> { - ExpressionContext { - globals: self.globals, - types: self.types, - ast_expressions: self.ast_expressions, - const_typifier: self.const_typifier, - module: self.module, - expr_type: match self.expr_type { - ExpressionContextType::Runtime(ref mut c) => { - ExpressionContextType::Runtime(c.reborrow()) - } - ExpressionContextType::Constant => ExpressionContextType::Constant, - }, - } - } - fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { ExpressionContext { globals: self.globals, @@ -438,9 +404,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { fn runtime_expression_ctx( &mut self, span: Span, - ) -> Result, Error<'source>> { + ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { match self.expr_type { - ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx.reborrow()), + ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), } } @@ -898,13 +864,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let init = v .init - .map(|init| self.expression(init, ctx.as_const())) + .map(|init| self.expression(init, &mut ctx.as_const())) .transpose()?; let binding = if let Some(ref binding) = v.binding { Some(crate::ResourceBinding { - group: self.const_u32(binding.group, ctx.as_const())?.0, - binding: self.const_u32(binding.binding, ctx.as_const())?.0, + group: self.const_u32(binding.group, &mut ctx.as_const())?.0, + binding: self.const_u32(binding.binding, &mut ctx.as_const())?.0, }) } else { None @@ -926,7 +892,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::GlobalDeclKind::Const(ref c) => { let mut ectx = ctx.as_const(); - let init = self.expression(c.init, ectx.reborrow())?; + let init = self.expression(c.init, &mut ectx)?; let inferred_type = ectx.register_type(init)?; let explicit_ty = @@ -1063,7 +1029,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut workgroup_size_out = [1; 3]; for (i, size) in workgroup_size.into_iter().enumerate() { if let Some(size_expr) = size { - workgroup_size_out[i] = self.const_u32(size_expr, ctx.as_const())?.0; + workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0; } } workgroup_size_out @@ -1117,7 +1083,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut emitter = Emitter::default(); emitter.start(ctx.naga_expressions); - let value = self.expression(l.init, ctx.as_expression(block, &mut emitter))?; + let value = + self.expression(l.init, &mut ctx.as_expression(block, &mut emitter))?; // The WGSL spec says that any expression that refers to a // `let`-bound variable is not a const expression. This @@ -1158,9 +1125,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { emitter.start(ctx.naga_expressions); let initializer = match v.init { - Some(init) => { - Some(self.expression(init, ctx.as_expression(block, &mut emitter))?) - } + Some(init) => Some( + self.expression(init, &mut ctx.as_expression(block, &mut emitter))?, + ), None => None, }; @@ -1253,7 +1220,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { emitter.start(ctx.naga_expressions); let condition = - self.expression(condition, ctx.as_expression(block, &mut emitter))?; + self.expression(condition, &mut ctx.as_expression(block, &mut emitter))?; block.extend(emitter.finish(ctx.naga_expressions)); let accept = self.block(accept, is_inside_loop, ctx)?; @@ -1273,7 +1240,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { emitter.start(ctx.naga_expressions); let mut ectx = ctx.as_expression(block, &mut emitter); - let selector = self.expression(selector, ectx.reborrow())?; + let selector = self.expression(selector, &mut ectx)?; let uint = resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint); @@ -1286,7 +1253,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { value: match case.value { ast::SwitchValue::Expr(expr) => { let span = ctx.ast_expressions.get_span(expr); - let expr = self.expression(expr, ctx.as_global().as_const())?; + let expr = + self.expression(expr, &mut ctx.as_global().as_const())?; match ctx.module.to_ctx().eval_expr_to_literal(expr) { Some(crate::Literal::I32(value)) if !uint => { crate::SwitchValue::I32(value) @@ -1320,7 +1288,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut emitter = Emitter::default(); emitter.start(ctx.naga_expressions); let break_if = break_if - .map(|expr| self.expression(expr, ctx.as_expression(block, &mut emitter))) + .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) .transpose()?; continuing.extend(emitter.finish(ctx.naga_expressions)); @@ -1337,7 +1305,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { emitter.start(ctx.naga_expressions); let value = value - .map(|expr| self.expression(expr, ctx.as_expression(block, &mut emitter))) + .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) .transpose()?; block.extend(emitter.finish(ctx.naga_expressions)); @@ -1355,7 +1323,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { stmt.span, function, arguments, - ctx.as_expression(block, &mut emitter), + &mut ctx.as_expression(block, &mut emitter), )?; block.extend(emitter.finish(ctx.naga_expressions)); return Ok(()); @@ -1364,9 +1332,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut emitter = Emitter::default(); emitter.start(ctx.naga_expressions); - let expr = - self.expression_for_reference(target, ctx.as_expression(block, &mut emitter))?; - let mut value = self.expression(value, ctx.as_expression(block, &mut emitter))?; + let expr = self.expression_for_reference( + target, + &mut ctx.as_expression(block, &mut emitter), + )?; + let mut value = + self.expression(value, &mut ctx.as_expression(block, &mut emitter))?; if !expr.is_reference { let ty = ctx.invalid_assignment_type(expr.handle); @@ -1411,8 +1382,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; let value_span = ctx.ast_expressions.get_span(value); - let reference = - self.expression_for_reference(value, ctx.as_expression(block, &mut emitter))?; + let reference = self + .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?; let mut ectx = ctx.as_expression(block, &mut emitter); let (kind, width) = match *resolve_inner!(ectx, reference.handle) { @@ -1459,7 +1430,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut emitter = Emitter::default(); emitter.start(ctx.naga_expressions); - let _ = self.expression(expr, ctx.as_expression(block, &mut emitter))?; + let _ = self.expression(expr, &mut ctx.as_expression(block, &mut emitter))?; block.extend(emitter.finish(ctx.naga_expressions)); return Ok(()); } @@ -1473,16 +1444,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn expression( &mut self, expr: Handle>, - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { - let expr = self.expression_for_reference(expr, ctx.reborrow())?; + let expr = self.expression_for_reference(expr, ctx)?; ctx.apply_load_rule(expr) } fn expression_for_reference( &mut self, expr: Handle>, - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result> { let span = ctx.ast_expressions.get_span(expr); let expr = &ctx.ast_expressions[expr]; @@ -1535,17 +1506,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty_span, ref components, } => { - let handle = self.construct(span, ty, ty_span, components, ctx.reborrow())?; + let handle = self.construct(span, ty, ty_span, components, ctx)?; return Ok(TypedExpression::non_reference(handle)); } ast::Expression::Unary { op, expr } => { - let expr = self.expression(expr, ctx.reborrow())?; + let expr = self.expression(expr, ctx)?; (crate::Expression::Unary { op, expr }, false) } ast::Expression::AddrOf(expr) => { // The `&` operator simply converts a reference to a pointer. And since a // reference is required, the Load Rule is not applied. - let expr = self.expression_for_reference(expr, ctx.reborrow())?; + let expr = self.expression_for_reference(expr, ctx)?; if !expr.is_reference { return Err(Error::NotReference("the operand of the `&` operator", span)); } @@ -1558,7 +1529,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::Expression::Deref(expr) => { // The pointer we dereference must be loaded. - let pointer = self.expression(expr, ctx.reborrow())?; + let pointer = self.expression(expr, ctx)?; if resolve_inner!(ctx, pointer).pointer_space().is_none() { return Err(Error::NotPointer(span)); @@ -1571,8 +1542,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::Expression::Binary { op, left, right } => { // Load both operands. - let mut left = self.expression(left, ctx.reborrow())?; - let mut right = self.expression(right, ctx.reborrow())?; + let mut left = self.expression(left, ctx)?; + let mut right = self.expression(right, ctx)?; ctx.binary_op_splat(op, &mut left, &mut right)?; (crate::Expression::Binary { op, left, right }, false) } @@ -1581,13 +1552,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ref arguments, } => { let handle = self - .call(span, function, arguments, ctx.reborrow())? + .call(span, function, arguments, ctx)? .ok_or(Error::FunctionReturnsVoid(function.span))?; return Ok(TypedExpression::non_reference(handle)); } ast::Expression::Index { base, index } => { - let expr = self.expression_for_reference(base, ctx.reborrow())?; - let index = self.expression(index, ctx.reborrow())?; + let expr = self.expression_for_reference(base, ctx)?; + let index = self.expression(index, ctx)?; let wgsl_pointer = resolve_inner!(ctx, expr.handle).pointer_space().is_some() && !expr.is_reference; @@ -1621,7 +1592,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let TypedExpression { handle, is_reference, - } = self.expression_for_reference(base, ctx.reborrow())?; + } = self.expression_for_reference(base, ctx)?; let temp_inner; let (composite, wgsl_pointer) = match *resolve_inner!(ctx, handle) { @@ -1704,7 +1675,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { access } ast::Expression::Bitcast { expr, to, ty_span } => { - let expr = self.expression(expr, ctx.reborrow())?; + let expr = self.expression(expr, ctx)?; let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?; let kind = match ctx.module.types[to_resolved].inner { @@ -1761,7 +1732,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span: Span, function: &ast::Ident<'source>, arguments: &[Handle>], - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result>, Error<'source>> { match ctx.globals.get(function.name) { Some(&LoweredGlobalDecl::Type(ty)) => { @@ -1770,7 +1741,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &ast::ConstructorType::Type(ty), function.span, arguments, - ctx.reborrow(), + ctx, )?; Ok(Some(handle)) } @@ -1781,7 +1752,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Some(&LoweredGlobalDecl::Function(function)) => { let arguments = arguments .iter() - .map(|&arg| self.expression(arg, ctx.reborrow())) + .map(|&arg| self.expression(arg, ctx)) .collect::, _>>()?; let has_result = ctx.module.functions[function].result.is_some(); @@ -1809,7 +1780,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let span = function.span; let expr = if let Some(fun) = conv::map_relational_fun(function.name) { let mut args = ctx.prepare_args(arguments, 1, span); - let argument = self.expression(args.next()?, ctx.reborrow())?; + let argument = self.expression(args.next()?, ctx)?; args.finish()?; // Check for no-op all(bool) and any(bool): @@ -1833,7 +1804,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { let mut args = ctx.prepare_args(arguments, 1, span); - let expr = self.expression(args.next()?, ctx.reborrow())?; + let expr = self.expression(args.next()?, ctx)?; args.finish()?; crate::Expression::Derivative { axis, ctrl, expr } @@ -1841,20 +1812,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expected = fun.argument_count() as _; let mut args = ctx.prepare_args(arguments, expected, span); - let arg = self.expression(args.next()?, ctx.reborrow())?; + let arg = self.expression(args.next()?, ctx)?; let arg1 = args .next() - .map(|x| self.expression(x, ctx.reborrow())) + .map(|x| self.expression(x, ctx)) .ok() .transpose()?; let arg2 = args .next() - .map(|x| self.expression(x, ctx.reborrow())) + .map(|x| self.expression(x, ctx)) .ok() .transpose()?; let arg3 = args .next() - .map(|x| self.expression(x, ctx.reborrow())) + .map(|x| self.expression(x, ctx)) .ok() .transpose()?; @@ -1886,15 +1857,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { arg3, } } else if let Some(fun) = Texture::map(function.name) { - self.texture_sample_helper(fun, arguments, span, ctx.reborrow())? + self.texture_sample_helper(fun, arguments, span, ctx)? } else { match function.name { "select" => { let mut args = ctx.prepare_args(arguments, 3, span); - let reject = self.expression(args.next()?, ctx.reborrow())?; - let accept = self.expression(args.next()?, ctx.reborrow())?; - let condition = self.expression(args.next()?, ctx.reborrow())?; + let reject = self.expression(args.next()?, ctx)?; + let accept = self.expression(args.next()?, ctx)?; + let condition = self.expression(args.next()?, ctx)?; args.finish()?; @@ -1906,22 +1877,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "arrayLength" => { let mut args = ctx.prepare_args(arguments, 1, span); - let expr = self.expression(args.next()?, ctx.reborrow())?; + let expr = self.expression(args.next()?, ctx)?; args.finish()?; crate::Expression::ArrayLength(expr) } "atomicLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); - let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?; + let pointer = self.atomic_pointer(args.next()?, ctx)?; args.finish()?; crate::Expression::Load { pointer } } "atomicStore" => { let mut args = ctx.prepare_args(arguments, 2, span); - let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?; - let value = self.expression(args.next()?, ctx.reborrow())?; + let pointer = self.atomic_pointer(args.next()?, ctx)?; + let value = self.expression(args.next()?, ctx)?; args.finish()?; let rctx = ctx.runtime_expression_ctx(span)?; @@ -1937,7 +1908,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, crate::AtomicFunction::Add, arguments, - ctx.reborrow(), + ctx, )?)) } "atomicSub" => { @@ -1945,7 +1916,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, crate::AtomicFunction::Subtract, arguments, - ctx.reborrow(), + ctx, )?)) } "atomicAnd" => { @@ -1953,7 +1924,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, crate::AtomicFunction::And, arguments, - ctx.reborrow(), + ctx, )?)) } "atomicOr" => { @@ -1961,7 +1932,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, crate::AtomicFunction::InclusiveOr, arguments, - ctx.reborrow(), + ctx, )?)) } "atomicXor" => { @@ -1969,7 +1940,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, crate::AtomicFunction::ExclusiveOr, arguments, - ctx.reborrow(), + ctx, )?)) } "atomicMin" => { @@ -1977,7 +1948,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, crate::AtomicFunction::Min, arguments, - ctx.reborrow(), + ctx, )?)) } "atomicMax" => { @@ -1985,7 +1956,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, crate::AtomicFunction::Max, arguments, - ctx.reborrow(), + ctx, )?)) } "atomicExchange" => { @@ -1993,19 +1964,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span, crate::AtomicFunction::Exchange { compare: None }, arguments, - ctx.reborrow(), + ctx, )?)) } "atomicCompareExchangeWeak" => { let mut args = ctx.prepare_args(arguments, 3, span); - let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?; + let pointer = self.atomic_pointer(args.next()?, ctx)?; - let compare = self.expression(args.next()?, ctx.reborrow())?; + let compare = self.expression(args.next()?, ctx)?; let value = args.next()?; let value_span = ctx.ast_expressions.get_span(value); - let value = self.expression(value, ctx.reborrow())?; + let value = self.expression(value, ctx)?; args.finish()?; @@ -2060,7 +2031,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expr = args.next()?; args.finish()?; - let pointer = self.expression(expr, ctx.reborrow())?; + let pointer = self.expression(expr, ctx)?; let result_ty = match *resolve_inner!(ctx, pointer) { crate::TypeInner::Pointer { base, @@ -2089,19 +2060,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let image = args.next()?; let image_span = ctx.ast_expressions.get_span(image); - let image = self.expression(image, ctx.reborrow())?; + let image = self.expression(image, ctx)?; - let coordinate = self.expression(args.next()?, ctx.reborrow())?; + let coordinate = self.expression(args.next()?, ctx)?; let (_, arrayed) = ctx.image_data(image, image_span)?; let array_index = arrayed .then(|| { args.min_args += 1; - self.expression(args.next()?, ctx.reborrow()) + self.expression(args.next()?, ctx) }) .transpose()?; - let value = self.expression(args.next()?, ctx.reborrow())?; + let value = self.expression(args.next()?, ctx)?; args.finish()?; @@ -2123,26 +2094,26 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let image = args.next()?; let image_span = ctx.ast_expressions.get_span(image); - let image = self.expression(image, ctx.reborrow())?; + let image = self.expression(image, ctx)?; - let coordinate = self.expression(args.next()?, ctx.reborrow())?; + let coordinate = self.expression(args.next()?, ctx)?; let (class, arrayed) = ctx.image_data(image, image_span)?; let array_index = arrayed .then(|| { args.min_args += 1; - self.expression(args.next()?, ctx.reborrow()) + self.expression(args.next()?, ctx) }) .transpose()?; let level = class .is_mipmapped() - .then(|| self.expression(args.next()?, ctx.reborrow())) + .then(|| self.expression(args.next()?, ctx)) .transpose()?; let sample = class .is_multisampled() - .then(|| self.expression(args.next()?, ctx.reborrow())) + .then(|| self.expression(args.next()?, ctx)) .transpose()?; args.finish()?; @@ -2157,10 +2128,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "textureDimensions" => { let mut args = ctx.prepare_args(arguments, 1, span); - let image = self.expression(args.next()?, ctx.reborrow())?; + let image = self.expression(args.next()?, ctx)?; let level = args .next() - .map(|arg| self.expression(arg, ctx.reborrow())) + .map(|arg| self.expression(arg, ctx)) .ok() .transpose()?; args.finish()?; @@ -2172,7 +2143,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "textureNumLevels" => { let mut args = ctx.prepare_args(arguments, 1, span); - let image = self.expression(args.next()?, ctx.reborrow())?; + let image = self.expression(args.next()?, ctx)?; args.finish()?; crate::Expression::ImageQuery { @@ -2182,7 +2153,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "textureNumLayers" => { let mut args = ctx.prepare_args(arguments, 1, span); - let image = self.expression(args.next()?, ctx.reborrow())?; + let image = self.expression(args.next()?, ctx)?; args.finish()?; crate::Expression::ImageQuery { @@ -2192,7 +2163,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "textureNumSamples" => { let mut args = ctx.prepare_args(arguments, 1, span); - let image = self.expression(args.next()?, ctx.reborrow())?; + let image = self.expression(args.next()?, ctx)?; args.finish()?; crate::Expression::ImageQuery { @@ -2202,10 +2173,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryInitialize" => { let mut args = ctx.prepare_args(arguments, 3, span); - let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; - let acceleration_structure = - self.expression(args.next()?, ctx.reborrow())?; - let descriptor = self.expression(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx)?; + let acceleration_structure = self.expression(args.next()?, ctx)?; + let descriptor = self.expression(args.next()?, ctx)?; args.finish()?; let _ = ctx.module.generate_ray_desc_type(); @@ -2224,7 +2194,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryProceed" => { let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let result = ctx.interrupt_emitter( @@ -2239,7 +2209,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "rayQueryGetCommittedIntersection" => { let mut args = ctx.prepare_args(arguments, 1, span); - let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?; + let query = self.ray_query_pointer(args.next()?, ctx)?; args.finish()?; let _ = ctx.module.generate_ray_intersection_type(); @@ -2256,7 +2226,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { &ast::ConstructorType::Type(ty), function.span, arguments, - ctx.reborrow(), + ctx, )?; return Ok(Some(handle)); } @@ -2273,10 +2243,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn atomic_pointer( &mut self, expr: Handle>, - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let span = ctx.ast_expressions.get_span(expr); - let pointer = self.expression(expr, ctx.reborrow())?; + let pointer = self.expression(expr, ctx)?; match *resolve_inner!(ctx, pointer) { crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { @@ -2298,14 +2268,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { span: Span, fun: crate::AtomicFunction, args: &[Handle>], - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let mut args = ctx.prepare_args(args, 2, span); - let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?; + let pointer = self.atomic_pointer(args.next()?, ctx)?; let value = args.next()?; - let value = self.expression(value, ctx.reborrow())?; + let value = self.expression(value, ctx)?; let ty = ctx.register_type(value)?; args.finish()?; @@ -2335,7 +2305,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fun: Texture, args: &[Handle>], span: Span, - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result> { let mut args = ctx.prepare_args(args, fun.min_argument_count(), span); @@ -2346,7 +2316,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ) -> Result<(Handle, Span), Error<'source>> { let image = args.next()?; let image_span = ctx.ast_expressions.get_span(image); - let image = lowerer.expression(image, ctx.reborrow())?; + let image = lowerer.expression(image, ctx)?; Ok((image, image_span)) } @@ -2355,8 +2325,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let image_or_component = args.next()?; let image_or_component_span = ctx.ast_expressions.get_span(image_or_component); // Gathers from depth textures don't take an initial `component` argument. - let lowered_image_or_component = - self.expression(image_or_component, ctx.reborrow())?; + let lowered_image_or_component = self.expression(image_or_component, ctx)?; match *resolve_inner!(ctx, lowered_image_or_component) { crate::TypeInner::Image { @@ -2368,7 +2337,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Some(crate::SwizzleComponent::X), ), _ => { - let (image, image_span) = get_image_and_span(self, &mut args, &mut ctx)?; + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; ( image, image_span, @@ -2382,59 +2351,59 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } Texture::GatherCompare => { - let (image, image_span) = get_image_and_span(self, &mut args, &mut ctx)?; + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; (image, image_span, Some(crate::SwizzleComponent::X)) } _ => { - let (image, image_span) = get_image_and_span(self, &mut args, &mut ctx)?; + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; (image, image_span, None) } }; - let sampler = self.expression(args.next()?, ctx.reborrow())?; + let sampler = self.expression(args.next()?, ctx)?; - let coordinate = self.expression(args.next()?, ctx.reborrow())?; + let coordinate = self.expression(args.next()?, ctx)?; let (_, arrayed) = ctx.image_data(image, image_span)?; let array_index = arrayed - .then(|| self.expression(args.next()?, ctx.reborrow())) + .then(|| self.expression(args.next()?, ctx)) .transpose()?; let (level, depth_ref) = match fun { Texture::Gather => (crate::SampleLevel::Zero, None), Texture::GatherCompare => { - let reference = self.expression(args.next()?, ctx.reborrow())?; + let reference = self.expression(args.next()?, ctx)?; (crate::SampleLevel::Zero, Some(reference)) } Texture::Sample => (crate::SampleLevel::Auto, None), Texture::SampleBias => { - let bias = self.expression(args.next()?, ctx.reborrow())?; + let bias = self.expression(args.next()?, ctx)?; (crate::SampleLevel::Bias(bias), None) } Texture::SampleCompare => { - let reference = self.expression(args.next()?, ctx.reborrow())?; + let reference = self.expression(args.next()?, ctx)?; (crate::SampleLevel::Auto, Some(reference)) } Texture::SampleCompareLevel => { - let reference = self.expression(args.next()?, ctx.reborrow())?; + let reference = self.expression(args.next()?, ctx)?; (crate::SampleLevel::Zero, Some(reference)) } Texture::SampleGrad => { - let x = self.expression(args.next()?, ctx.reborrow())?; - let y = self.expression(args.next()?, ctx.reborrow())?; + let x = self.expression(args.next()?, ctx)?; + let y = self.expression(args.next()?, ctx)?; (crate::SampleLevel::Gradient { x, y }, None) } Texture::SampleLevel => { - let level = self.expression(args.next()?, ctx.reborrow())?; + let level = self.expression(args.next()?, ctx)?; (crate::SampleLevel::Exact(level), None) } }; let offset = args .next() - .map(|arg| self.expression(arg, ctx.as_const())) + .map(|arg| self.expression(arg, &mut ctx.as_const())) .ok() .transpose()?; @@ -2471,7 +2440,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let member_min_alignment = self.layouter[ty].alignment; let member_size = if let Some(size_expr) = member.size { - let (size, span) = self.const_u32(size_expr, ctx.as_const())?; + let (size, span) = self.const_u32(size_expr, &mut ctx.as_const())?; if size < member_min_size { return Err(Error::SizeAttributeTooLow(span, member_min_size)); } else { @@ -2482,7 +2451,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; let member_alignment = if let Some(align_expr) = member.align { - let (align, span) = self.const_u32(align_expr, ctx.as_const())?; + let (align, span) = self.const_u32(align_expr, &mut ctx.as_const())?; if let Some(alignment) = Alignment::new(align) { if alignment < member_min_alignment { return Err(Error::AlignAttributeTooLow(span, member_min_alignment)); @@ -2530,10 +2499,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn const_u32( &mut self, expr: Handle>, - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result<(u32, Span), Error<'source>> { let span = ctx.ast_expressions.get_span(expr); - let expr = self.expression(expr, ctx.reborrow())?; + let expr = self.expression(expr, ctx)?; let value = ctx .module .to_ctx() @@ -2555,7 +2524,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(match size { ast::ArraySize::Constant(expr) => { let span = ctx.ast_expressions.get_span(expr); - let const_expr = self.expression(expr, ctx.as_const())?; + let const_expr = self.expression(expr, &mut ctx.as_const())?; let len = ctx.module .to_ctx() @@ -2659,7 +2628,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { sampling, }) => { let mut binding = crate::Binding::Location { - location: self.const_u32(location, ctx.as_const())?.0, + location: self.const_u32(location, &mut ctx.as_const())?.0, second_blend_source, interpolation, sampling, @@ -2674,10 +2643,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn ray_query_pointer( &mut self, expr: Handle>, - mut ctx: ExpressionContext<'source, '_, '_>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let span = ctx.ast_expressions.get_span(expr); - let pointer = self.expression(expr, ctx.reborrow())?; + let pointer = self.expression(expr, ctx)?; match *resolve_inner!(ctx, pointer) { crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {