Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tighten up ConstantEvaluator's public API. #2520

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 20 additions & 26 deletions src/front/glsl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,45 +248,39 @@ impl<'a> Context<'a> {
}

pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
let (expressions, function_info) = if self.is_const {
(&mut self.module.const_expressions, None)
let mut eval = if self.is_const {
crate::proc::ConstantEvaluator::for_module(self.module)
} else {
(
crate::proc::ConstantEvaluator::for_function(
self.module,
&mut self.expressions,
Some(crate::proc::FunctionLocalData {
const_expressions: &self.module.const_expressions,
expression_constness: &mut self.expression_constness,
emitter: &mut self.emitter,
block: &mut self.body,
}),
&mut self.expression_constness,
&mut self.emitter,
&mut self.body,
)
};

let mut eval = crate::proc::ConstantEvaluator {
types: &mut self.module.types,
constants: &self.module.constants,
expressions,
function_local_data: function_info,
};

let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error {
kind: e.into(),
meta,
});

match res {
Ok(expr) => Ok(expr),
Err(e) if self.is_const => Err(e),
Err(_) => {
let needs_pre_emit = expr.needs_pre_emit();
if needs_pre_emit {
self.body.extend(self.emitter.finish(expressions));
}
let h = expressions.append(expr, meta);
if needs_pre_emit {
self.emitter.start(expressions);
Err(e) => {
if self.is_const {
Err(e)
} else {
let needs_pre_emit = expr.needs_pre_emit();
if needs_pre_emit {
self.body.extend(self.emitter.finish(&self.expressions));
}
let h = self.expressions.append(expr, meta);
if needs_pre_emit {
self.emitter.start(&self.expressions);
}
Ok(h)
}
Ok(h)
}
}
}
Expand Down
30 changes: 10 additions & 20 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::front::wgsl::parse::number::Number;
use crate::front::wgsl::parse::{ast, conv};
use crate::front::Typifier;
use crate::proc::{
ensure_block_returns, Alignment, ConstantEvaluator, Emitter, FunctionLocalData, Layouter,
ResolveContext, TypeResolution,
ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext,
TypeResolution,
};
use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span};

Expand Down Expand Up @@ -340,31 +340,21 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<Handle<crate::Expression>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
let mut eval = ConstantEvaluator {
types: &mut self.module.types,
constants: &self.module.constants,
expressions: rctx.naga_expressions,
function_local_data: Some(FunctionLocalData {
const_expressions: &self.module.const_expressions,
expression_constness: rctx.expression_constness,
emitter: rctx.emitter,
block: rctx.block,
}),
};
let mut eval = ConstantEvaluator::for_function(
self.module,
rctx.naga_expressions,
rctx.expression_constness,
rctx.emitter,
rctx.block,
);

match eval.try_eval_and_append(&expr, span) {
Ok(expr) => Ok(expr),
Err(_) => Ok(rctx.naga_expressions.append(expr, span)),
}
}
ExpressionContextType::Constant => {
let mut eval = ConstantEvaluator {
types: &mut self.module.types,
constants: &self.module.constants,
expressions: &mut self.module.const_expressions,
function_local_data: None,
};

let mut eval = ConstantEvaluator::for_module(self.module);
eval.try_eval_and_append(&expr, span)
.map_err(|e| Error::ConstantEvaluatorError(e, span))
}
Expand Down
51 changes: 40 additions & 11 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ use crate::{

#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
pub types: &'a mut UniqueArena<Type>,
pub constants: &'a Arena<Constant>,
pub expressions: &'a mut Arena<Expression>,
types: &'a mut UniqueArena<Type>,
constants: &'a Arena<Constant>,
expressions: &'a mut Arena<Expression>,

/// When `self.expressions` refers to a function's local expression
/// arena, this needs to be populated
pub function_local_data: Option<FunctionLocalData<'a>>,
function_local_data: Option<FunctionLocalData<'a>>,
}

#[derive(Debug)]
pub struct FunctionLocalData<'a> {
struct FunctionLocalData<'a> {
/// Global constant expressions
pub const_expressions: &'a Arena<Expression>,
const_expressions: &'a Arena<Expression>,
/// Tracks the constness of expressions residing in `ConstantEvaluator.expressions`
pub expression_constness: &'a mut ExpressionConstnessTracker,
pub emitter: &'a mut super::Emitter,
pub block: &'a mut crate::Block,
expression_constness: &'a mut ExpressionConstnessTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
}

#[derive(Debug)]
Expand All @@ -37,7 +37,7 @@ impl ExpressionConstnessTracker {
}
}

pub fn insert(&mut self, value: Handle<Expression>) {
fn insert(&mut self, value: Handle<Expression>) {
self.inner.insert(value.index());
}

Expand Down Expand Up @@ -137,7 +137,36 @@ pub enum ConstantEvaluatorError {
// Math
// As

impl ConstantEvaluator<'_> {
impl<'a> ConstantEvaluator<'a> {
pub fn for_module(module: &'a mut crate::Module) -> Self {
Self {
types: &mut module.types,
constants: &module.constants,
expressions: &mut module.const_expressions,
function_local_data: None,
}
}

pub fn for_function(
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
expression_constness: &'a mut ExpressionConstnessTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
Self {
types: &mut module.types,
constants: &module.constants,
expressions,
function_local_data: Some(FunctionLocalData {
const_expressions: &module.const_expressions,
expression_constness,
emitter,
block,
}),
}
}

fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
if let Some(ref extra_data) = self.function_local_data {
if !extra_data.expression_constness.is_const(expr) {
Expand Down
2 changes: 1 addition & 1 deletion src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod terminator;
mod typifier;

pub use constant_evaluator::{
ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, FunctionLocalData,
ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker,
};
pub use emitter::Emitter;
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
Expand Down