diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index 0f836def69..0fe0f734a3 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -245,19 +245,6 @@ impl<'a> Context<'a> { } pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result> { - let mut append = |arena: &mut Arena, expr: Expression, span| { - let is_running = self.emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - if is_running && needs_pre_emit { - self.body.extend(self.emitter.finish(arena)); - } - let h = arena.append(expr, span); - if is_running && needs_pre_emit { - self.emitter.start(arena); - } - h - }; - let (expressions, const_expressions) = if self.is_const { (&mut self.module.const_expressions, None) } else { @@ -269,7 +256,10 @@ impl<'a> Context<'a> { constants: &self.module.constants, expressions, const_expressions, - append: (!self.is_const).then_some(&mut append), + emitter: (!self.is_const).then_some(crate::proc::ConstantEvaluatorEmitter { + emitter: &mut self.emitter, + block: &mut self.body, + }), }; let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error { @@ -280,7 +270,17 @@ impl<'a> Context<'a> { match res { Ok(expr) => Ok(expr), Err(e) if self.is_const => Err(e), - Err(_) => Ok(append(&mut self.expressions, expr, meta)), + Err(_) => { + let needs_pre_emit = expr.needs_pre_emit(); + if needs_pre_emit { + self.body.extend(self.emitter.finish(expressions)); + } + let h = expressions.append(expr, meta); + if needs_pre_emit { + self.emitter.start(expressions); + } + Ok(h) + } } } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 5ece316640..1973cdbe41 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -6,8 +6,8 @@ use crate::front::wgsl::parse::number::Number; use crate::front::wgsl::parse::{ast, conv}; use crate::front::Typifier; use crate::proc::{ - ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext, - TypeResolution, + ensure_block_returns, Alignment, ConstantEvaluator, ConstantEvaluatorEmitter, Emitter, + Layouter, ResolveContext, TypeResolution, }; use crate::{Arena, FastHashMap, Handle, Span}; use indexmap::IndexMap; @@ -328,20 +328,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { constants: &self.module.constants, expressions: rctx.naga_expressions, const_expressions: Some(&self.module.const_expressions), - append: Some( - |arena: &mut Arena, expr: crate::Expression, span| { - let is_running = rctx.emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - if is_running && needs_pre_emit { - rctx.block.extend(rctx.emitter.finish(arena)); - } - let h = arena.append(expr, span); - if is_running && needs_pre_emit { - rctx.emitter.start(arena); - } - h - }, - ), + emitter: Some(ConstantEvaluatorEmitter { + emitter: rctx.emitter, + block: rctx.block, + }), }; let res = eval @@ -359,15 +349,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { constants: &self.module.constants, expressions: &mut self.module.const_expressions, const_expressions: None, - append: None::< - Box< - dyn FnMut( - &mut Arena, - crate::Expression, - Span, - ) -> Handle, - >, - >, + emitter: None, }; eval.try_eval_and_append(&expr, span) diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 1d390dfc66..19b635072e 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -5,15 +5,22 @@ use crate::{ }; #[derive(Debug)] -pub struct ConstantEvaluator< - 'a, - F: FnMut(&mut Arena, Expression, Span) -> Handle, -> { +pub struct ConstantEvaluator<'a> { pub types: &'a mut UniqueArena, pub constants: &'a Arena, pub expressions: &'a mut Arena, pub const_expressions: Option<&'a Arena>, - pub append: Option, + + /// When `expressions` refers to a function's local expression + /// arena, this is the emitter we should interrupt when inserting + /// new things into it. + pub emitter: Option>, +} + +#[derive(Debug)] +pub struct ConstantEvaluatorEmitter<'a> { + pub emitter: &'a mut super::Emitter, + pub block: &'a mut crate::Block, } #[derive(Debug)] @@ -106,9 +113,7 @@ impl Arena { } } -impl<'a, F: FnMut(&mut Arena, Expression, Span) -> Handle> - ConstantEvaluator<'a, F> -{ +impl ConstantEvaluator<'_> { fn check_and_get( &mut self, expr: Handle, @@ -807,11 +812,20 @@ impl<'a, F: FnMut(&mut Arena, Expression, Span) -> Handle Handle { - if let Some(ref mut append) = self.append { - append(self.expressions, expr, span) - } else { - self.expressions.append(expr, span) + if let Some(ref mut emitter) = self.emitter { + let is_running = emitter.emitter.is_running(); + let needs_pre_emit = expr.needs_pre_emit(); + if is_running && needs_pre_emit { + emitter + .block + .extend(emitter.emitter.finish(self.expressions)); + let h = self.expressions.append(expr, span); + emitter.emitter.start(self.expressions); + return h; + } } + + self.expressions.append(expr, span) } } @@ -980,15 +994,7 @@ mod tests { constants: &constants, expressions: &mut const_expressions, const_expressions: None, - append: None::< - Box< - dyn FnMut( - &mut Arena, - Expression, - crate::Span, - ) -> crate::Handle, - >, - >, + emitter: None, }; let res1 = solver @@ -1075,15 +1081,7 @@ mod tests { constants: &constants, expressions: &mut const_expressions, const_expressions: None, - append: None::< - Box< - dyn FnMut( - &mut Arena, - Expression, - crate::Span, - ) -> crate::Handle, - >, - >, + emitter: None, }; let res = solver @@ -1202,15 +1200,7 @@ mod tests { constants: &constants, expressions: &mut const_expressions, const_expressions: None, - append: None::< - Box< - dyn FnMut( - &mut Arena, - Expression, - crate::Span, - ) -> crate::Handle, - >, - >, + emitter: None, }; let root1 = Expression::AccessIndex { base, index: 1 }; diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 0bf81fb173..35b88537c1 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -10,7 +10,7 @@ mod namer; mod terminator; mod typifier; -pub use constant_evaluator::{ConstantEvaluator, ConstantEvaluatorError}; +pub use constant_evaluator::{ConstantEvaluator, ConstantEvaluatorEmitter, ConstantEvaluatorError}; pub use emitter::Emitter; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};