From 802bede81f12ed770ef4a60686e3c3e4d4395167 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 21 Oct 2023 17:29:14 -0700 Subject: [PATCH] [wgsl-in] Let lowering contexts point to a Function directly. Change `StatementContext` and `RuntimeExpressionContext` in `front::wgsl::lower` to hold a `&mut crate::Function` reference, rather than separate pointers to individual fields of the `Function`. This replaces three fields with one, and clarifies their relationships. --- src/front/wgsl/lower/mod.rs | 155 ++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 78 deletions(-) diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 9176acca3f..ddec74b71f 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -142,12 +142,10 @@ pub struct StatementContext<'source, 'temp, 'out> { const_typifier: &'temp mut Typifier, typifier: &'temp mut Typifier, - variables: &'out mut Arena, - naga_expressions: &'out mut Arena, + function: &'out mut crate::Function, /// Stores the names of expressions that are assigned in `let` statement /// Also stores the spans of the names, for use in errors. named_expressions: &'out mut FastIndexMap, (String, Span)>, - arguments: &'out [crate::FunctionArgument], module: &'out mut crate::Module, /// Which `Expression`s in `self.naga_expressions` are const expressions, in @@ -179,12 +177,10 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { module: self.module, expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { local_table: self.local_table, - naga_expressions: self.naga_expressions, - local_vars: self.variables, - arguments: self.arguments, - typifier: self.typifier, + function: self.function, block, emitter, + typifier: self.typifier, expression_constness: self.expression_constness, }), } @@ -204,7 +200,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { if let Some(&(_, span)) = self.named_expressions.get(&expr) { InvalidAssignmentType::ImmutableBinding(span) } else { - match self.naga_expressions[expr] { + match self.function.expressions[expr] { crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle, crate::Expression::Access { base, .. } => self.invalid_assignment_type(base), crate::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base), @@ -221,9 +217,7 @@ pub struct RuntimeExpressionContext<'temp, 'out> { /// enclosing statement; see that documentation for details. local_table: &'temp FastHashMap, TypedExpression>, - naga_expressions: &'out mut Arena, - local_vars: &'out Arena, - arguments: &'out [crate::FunctionArgument], + function: &'out mut crate::Function, block: &'temp mut crate::Block, emitter: &'temp mut Emitter, typifier: &'temp mut Typifier, @@ -342,7 +336,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { match self.expr_type { ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function( self.module, - rctx.naga_expressions, + &mut rctx.function.expressions, rctx.expression_constness, rctx.emitter, rctx.block, @@ -364,7 +358,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { // long as we're not building `Module::const_expressions`. Err(err) => match self.expr_type { ExpressionContextType::Runtime(ref mut rctx) => { - Ok(rctx.naga_expressions.append(expr, span)) + Ok(rctx.function.expressions.append(expr, span)) } ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)), }, @@ -380,7 +374,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { self.module .to_ctx() - .eval_expr_to_u32_from(handle, ctx.naga_expressions) + .eval_expr_to_u32_from(handle, &ctx.function.expressions) .ok() } ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), @@ -389,7 +383,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { fn get_expression_span(&self, handle: Handle) -> Span { match self.expr_type { - ExpressionContextType::Runtime(ref ctx) => ctx.naga_expressions.get_span(handle), + ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), ExpressionContextType::Constant => self.module.const_expressions.get_span(handle), } } @@ -428,7 +422,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { let index = self .module .to_ctx() - .eval_expr_to_u32_from(expr, rctx.naga_expressions) + .eval_expr_to_u32_from(expr, &rctx.function.expressions) .map_err(|err| match err { crate::proc::U32EvalError::NonConst => { Error::ExpectedConstExprConcreteIntegerScalar(component_span) @@ -478,7 +472,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { /// Resolve the types of all expressions up through `handle`. /// /// Ensure that [`self.typifier`] has a [`TypeResolution`] for - /// every expression in [`self.naga_expressions`]. + /// every expression in [`self.function.expressions`]. /// /// This does not add types to any arena. The [`Typifier`] /// documentation explains the steps we take to avoid filling @@ -499,20 +493,23 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { handle: Handle, ) -> Result<&mut Self, Error<'source>> { let empty_arena = Arena::new(); - let resolve_ctx = match self.expr_type { - ExpressionContextType::Runtime(ref ctx) => { - ResolveContext::with_locals(self.module, ctx.local_vars, ctx.arguments) - } - ExpressionContextType::Constant => { - ResolveContext::with_locals(self.module, &empty_arena, &[]) - } - }; - let (typifier, expressions) = match self.expr_type { + let resolve_ctx; + let typifier; + let expressions; + match self.expr_type { ExpressionContextType::Runtime(ref mut ctx) => { - (&mut *ctx.typifier, &*ctx.naga_expressions) + resolve_ctx = ResolveContext::with_locals( + self.module, + &ctx.function.local_variables, + &ctx.function.arguments, + ); + typifier = &mut *ctx.typifier; + expressions = &ctx.function.expressions; } ExpressionContextType::Constant => { - (&mut *self.const_typifier, &self.module.const_expressions) + resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); + typifier = self.const_typifier; + expressions = &self.module.const_expressions; } }; typifier @@ -603,14 +600,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { match self.expr_type { ExpressionContextType::Runtime(ref mut rctx) => { rctx.block - .extend(rctx.emitter.finish(rctx.naga_expressions)); + .extend(rctx.emitter.finish(&rctx.function.expressions)); } ExpressionContextType::Constant => {} } let result = self.append_expression(expression, span); match self.expr_type { ExpressionContextType::Runtime(ref mut rctx) => { - rctx.emitter.start(rctx.naga_expressions); + rctx.emitter.start(&rctx.function.expressions); } ExpressionContextType::Constant => {} } @@ -957,7 +954,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx: &mut GlobalContext<'source, '_, '_>, ) -> Result> { let mut local_table = FastHashMap::default(); - let mut local_variables = Arena::new(); let mut expressions = Arena::new(); let mut named_expressions = FastIndexMap::default(); @@ -992,6 +988,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }) .transpose()?; + let mut function = crate::Function { + name: Some(f.name.name.to_string()), + arguments, + result, + local_variables: Arena::new(), + expressions, + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::default(), + }; + let mut typifier = Typifier::default(); let mut stmt_ctx = StatementContext { local_table: &mut local_table, @@ -999,29 +1005,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast_expressions: ctx.ast_expressions, const_typifier: ctx.const_typifier, typifier: &mut typifier, - variables: &mut local_variables, - naga_expressions: &mut expressions, + function: &mut function, named_expressions: &mut named_expressions, types: ctx.types, module: ctx.module, - arguments: &arguments, expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), }; let mut body = self.block(&f.body, false, &mut stmt_ctx)?; ensure_block_returns(&mut body); - let function = crate::Function { - name: Some(f.name.name.to_string()), - arguments, - result, - local_variables, - expressions, - named_expressions: named_expressions - .into_iter() - .map(|(key, (name, _))| (key, name)) - .collect(), - body, - }; + function.body = body; + function.named_expressions = named_expressions + .into_iter() + .map(|(key, (name, _))| (key, name)) + .collect(); if let Some(ref entry) = f.entry_point { let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size { @@ -1081,7 +1078,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::StatementKind::LocalDecl(ref decl) => match *decl { ast::LocalDecl::Let(ref l) => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let value = self.expression(l.init, &mut ctx.as_expression(block, &mut emitter))?; @@ -1112,7 +1109,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); ctx.local_table .insert(l.handle, TypedExpression::non_reference(value)); ctx.named_expressions @@ -1122,7 +1119,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::LocalDecl::Var(ref v) => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let initializer = match v.init { Some(init) => Some( @@ -1180,7 +1177,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } }; - let var = ctx.variables.append( + let var = ctx.function.local_variables.append( crate::LocalVariable { name: Some(v.name.name.to_string()), ty, @@ -1193,7 +1190,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { crate::Expression::LocalVariable(var), Span::UNDEFINED, )?; - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); ctx.local_table.insert( v.handle, TypedExpression { @@ -1217,11 +1214,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ref reject, } => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let condition = self.expression(condition, &mut ctx.as_expression(block, &mut emitter))?; - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); let accept = self.block(accept, is_inside_loop, ctx)?; let reject = self.block(reject, is_inside_loop, ctx)?; @@ -1237,14 +1234,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ref cases, } => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let mut ectx = ctx.as_expression(block, &mut emitter); let selector = self.expression(selector, &mut ectx)?; let uint = resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint); - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); let cases = cases .iter() @@ -1286,11 +1283,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut continuing = self.block(continuing, true, ctx)?; let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let break_if = break_if .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) .transpose()?; - continuing.extend(emitter.finish(ctx.naga_expressions)); + continuing.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Loop { body, @@ -1302,12 +1299,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::StatementKind::Continue => crate::Statement::Continue, ast::StatementKind::Return { value } => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let value = value .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) .transpose()?; - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Return { value } } @@ -1317,7 +1314,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ref arguments, } => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let _ = self.call( stmt.span, @@ -1325,12 +1322,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { arguments, &mut ctx.as_expression(block, &mut emitter), )?; - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); return Ok(()); } ast::StatementKind::Assign { target, op, value } => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let expr = self.expression_for_reference( target, @@ -1364,7 +1361,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } None => value, }; - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Store { pointer: expr.handle, @@ -1373,7 +1370,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let op = match stmt.kind { ast::StatementKind::Increment(_) => crate::BinaryOperator::Add, @@ -1410,17 +1407,18 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let right = ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED)?; let rctx = ectx.runtime_expression_ctx(stmt.span)?; - let left = rctx.naga_expressions.append( + let left = rctx.function.expressions.append( crate::Expression::Load { pointer: reference.handle, }, value_span, ); let value = rctx - .naga_expressions + .function + .expressions .append(crate::Expression::Binary { op, left, right }, stmt.span); - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Store { pointer: reference.handle, value, @@ -1428,10 +1426,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::StatementKind::Ignore(expr) => { let mut emitter = Emitter::default(); - emitter.start(ctx.naga_expressions); + emitter.start(&ctx.function.expressions); let _ = self.expression(expr, &mut ctx.as_expression(block, &mut emitter))?; - block.extend(emitter.finish(ctx.naga_expressions)); + block.extend(emitter.finish(&ctx.function.expressions)); return Ok(()); } }; @@ -1759,12 +1757,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; // we need to always do this before a fn call since all arguments need to be emitted before the fn call rctx.block - .extend(rctx.emitter.finish(rctx.naga_expressions)); + .extend(rctx.emitter.finish(&rctx.function.expressions)); let result = has_result.then(|| { - rctx.naga_expressions + rctx.function + .expressions .append(crate::Expression::CallResult(function), span) }); - rctx.emitter.start(rctx.naga_expressions); + rctx.emitter.start(&rctx.function.expressions); rctx.block.push( crate::Statement::Call { function, @@ -1897,8 +1896,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block - .extend(rctx.emitter.finish(rctx.naga_expressions)); - rctx.emitter.start(rctx.naga_expressions); + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); rctx.block .push(crate::Statement::Store { pointer, value }, span); return Ok(None); @@ -2078,8 +2077,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block - .extend(rctx.emitter.finish(rctx.naga_expressions)); - rctx.emitter.start(rctx.naga_expressions); + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); let stmt = crate::Statement::ImageStore { image, coordinate, @@ -2186,8 +2185,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block - .extend(rctx.emitter.finish(rctx.naga_expressions)); - rctx.emitter.start(rctx.naga_expressions); + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); rctx.block .push(crate::Statement::RayQuery { query, fun }, span); return Ok(None);