Skip to content

Commit

Permalink
Tighten up ConstantEvaluator's public API. (#2520)
Browse files Browse the repository at this point in the history
- Make the fields of `ConstantEvaluator` private to the module.
- Add constructor functions `for_module` and `for_function`.
- Make `FunctionLocalData` private.
  • Loading branch information
jimblandy authored Sep 27, 2023
1 parent b7b69ee commit 514571f
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 58 deletions.
46 changes: 20 additions & 26 deletions src/front/glsl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,45 +248,39 @@ impl<'a> Context<'a> {
}

pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
let (expressions, function_info) = if self.is_const {
(&mut self.module.const_expressions, None)
let mut eval = if self.is_const {
crate::proc::ConstantEvaluator::for_module(self.module)
} else {
(
crate::proc::ConstantEvaluator::for_function(
self.module,
&mut self.expressions,
Some(crate::proc::FunctionLocalData {
const_expressions: &self.module.const_expressions,
expression_constness: &mut self.expression_constness,
emitter: &mut self.emitter,
block: &mut self.body,
}),
&mut self.expression_constness,
&mut self.emitter,
&mut self.body,
)
};

let mut eval = crate::proc::ConstantEvaluator {
types: &mut self.module.types,
constants: &self.module.constants,
expressions,
function_local_data: function_info,
};

let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error {
kind: e.into(),
meta,
});

match res {
Ok(expr) => Ok(expr),
Err(e) if self.is_const => Err(e),
Err(_) => {
let needs_pre_emit = expr.needs_pre_emit();
if needs_pre_emit {
self.body.extend(self.emitter.finish(expressions));
}
let h = expressions.append(expr, meta);
if needs_pre_emit {
self.emitter.start(expressions);
Err(e) => {
if self.is_const {
Err(e)
} else {
let needs_pre_emit = expr.needs_pre_emit();
if needs_pre_emit {
self.body.extend(self.emitter.finish(&self.expressions));
}
let h = self.expressions.append(expr, meta);
if needs_pre_emit {
self.emitter.start(&self.expressions);
}
Ok(h)
}
Ok(h)
}
}
}
Expand Down
30 changes: 10 additions & 20 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::front::wgsl::parse::number::Number;
use crate::front::wgsl::parse::{ast, conv};
use crate::front::Typifier;
use crate::proc::{
ensure_block_returns, Alignment, ConstantEvaluator, Emitter, FunctionLocalData, Layouter,
ResolveContext, TypeResolution,
ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext,
TypeResolution,
};
use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span};

Expand Down Expand Up @@ -340,31 +340,21 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<Handle<crate::Expression>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
let mut eval = ConstantEvaluator {
types: &mut self.module.types,
constants: &self.module.constants,
expressions: rctx.naga_expressions,
function_local_data: Some(FunctionLocalData {
const_expressions: &self.module.const_expressions,
expression_constness: rctx.expression_constness,
emitter: rctx.emitter,
block: rctx.block,
}),
};
let mut eval = ConstantEvaluator::for_function(
self.module,
rctx.naga_expressions,
rctx.expression_constness,
rctx.emitter,
rctx.block,
);

match eval.try_eval_and_append(&expr, span) {
Ok(expr) => Ok(expr),
Err(_) => Ok(rctx.naga_expressions.append(expr, span)),
}
}
ExpressionContextType::Constant => {
let mut eval = ConstantEvaluator {
types: &mut self.module.types,
constants: &self.module.constants,
expressions: &mut self.module.const_expressions,
function_local_data: None,
};

let mut eval = ConstantEvaluator::for_module(self.module);
eval.try_eval_and_append(&expr, span)
.map_err(|e| Error::ConstantEvaluatorError(e, span))
}
Expand Down
51 changes: 40 additions & 11 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ use crate::{

#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
pub types: &'a mut UniqueArena<Type>,
pub constants: &'a Arena<Constant>,
pub expressions: &'a mut Arena<Expression>,
types: &'a mut UniqueArena<Type>,
constants: &'a Arena<Constant>,
expressions: &'a mut Arena<Expression>,

/// When `self.expressions` refers to a function's local expression
/// arena, this needs to be populated
pub function_local_data: Option<FunctionLocalData<'a>>,
function_local_data: Option<FunctionLocalData<'a>>,
}

#[derive(Debug)]
pub struct FunctionLocalData<'a> {
struct FunctionLocalData<'a> {
/// Global constant expressions
pub const_expressions: &'a Arena<Expression>,
const_expressions: &'a Arena<Expression>,
/// Tracks the constness of expressions residing in `ConstantEvaluator.expressions`
pub expression_constness: &'a mut ExpressionConstnessTracker,
pub emitter: &'a mut super::Emitter,
pub block: &'a mut crate::Block,
expression_constness: &'a mut ExpressionConstnessTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
}

#[derive(Debug)]
Expand All @@ -37,7 +37,7 @@ impl ExpressionConstnessTracker {
}
}

pub fn insert(&mut self, value: Handle<Expression>) {
fn insert(&mut self, value: Handle<Expression>) {
self.inner.insert(value.index());
}

Expand Down Expand Up @@ -137,7 +137,36 @@ pub enum ConstantEvaluatorError {
// Math
// As

impl ConstantEvaluator<'_> {
impl<'a> ConstantEvaluator<'a> {
pub fn for_module(module: &'a mut crate::Module) -> Self {
Self {
types: &mut module.types,
constants: &module.constants,
expressions: &mut module.const_expressions,
function_local_data: None,
}
}

pub fn for_function(
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
expression_constness: &'a mut ExpressionConstnessTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
Self {
types: &mut module.types,
constants: &module.constants,
expressions,
function_local_data: Some(FunctionLocalData {
const_expressions: &module.const_expressions,
expression_constness,
emitter,
block,
}),
}
}

fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
if let Some(ref extra_data) = self.function_local_data {
if !extra_data.expression_constness.is_const(expr) {
Expand Down
2 changes: 1 addition & 1 deletion src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod terminator;
mod typifier;

pub use constant_evaluator::{
ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, FunctionLocalData,
ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker,
};
pub use emitter::Emitter;
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
Expand Down

0 comments on commit 514571f

Please sign in to comment.