Skip to content

Commit

Permalink
use LocalVariable init
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Jun 13, 2023
1 parent 6574ec5 commit 6f64c99
Show file tree
Hide file tree
Showing 96 changed files with 3,984 additions and 4,404 deletions.
2 changes: 1 addition & 1 deletion src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ impl<'a, W: Write> Writer<'a, W> {

// Write the constant
// `write_constant` adds no trailing or leading space/newline
self.write_const_expr(init)?;
self.write_expr(init, &ctx)?;
} else if is_value_init_supported(self.module, local.ty) {
write!(self.out, " = ")?;
self.write_zero_init_value(local.ty)?;
Expand Down
2 changes: 1 addition & 1 deletion src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, " = ")?;
// Write the local initializer if needed
if let Some(init) = local.init {
self.write_const_expression(module, init)?;
self.write_expr(module, init, func_ctx)?;
} else {
// Zero initialize local variables
self.write_default_init(module, local.ty)?;
Expand Down
70 changes: 36 additions & 34 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3459,6 +3459,23 @@ impl<W: Write> Writer<W> {

writeln!(self.out, ") {{")?;

let guarded_indices =
index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);

let context = StatementContext {
expression: ExpressionContext {
function: fun,
origin: FunctionOrigin::Handle(fun_handle),
info: fun_info,
policies: options.bounds_check_policies,
guarded_indices,
module,
mod_info,
pipeline_options,
},
result_struct: None,
};

for (local_handle, local) in fun.local_variables.iter() {
let ty_name = TypeContext {
handle: local.ty,
Expand All @@ -3473,7 +3490,7 @@ impl<W: Write> Writer<W> {
match local.init {
Some(value) => {
write!(self.out, " = ")?;
self.put_const_expression(value, module, mod_info)?;
self.put_expression(value, &context.expression, true)?;
}
None => {
write!(self.out, " = {{}}")?;
Expand All @@ -3482,22 +3499,6 @@ impl<W: Write> Writer<W> {
writeln!(self.out, ";")?;
}

let guarded_indices =
index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);

let context = StatementContext {
expression: ExpressionContext {
function: fun,
origin: FunctionOrigin::Handle(fun_handle),
info: fun_info,
policies: options.bounds_check_policies,
guarded_indices,
module,
mod_info,
pipeline_options,
},
result_struct: None,
};
self.named_expressions.clear();
self.update_expressions_to_bake(fun, fun_info, &context.expression);
self.put_block(back::Level(1), &fun.body, &context)?;
Expand Down Expand Up @@ -4015,6 +4016,23 @@ impl<W: Write> Writer<W> {
}
}

let guarded_indices =
index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);

let context = StatementContext {
expression: ExpressionContext {
function: fun,
origin: FunctionOrigin::EntryPoint(ep_index as _),
info: fun_info,
policies: options.bounds_check_policies,
guarded_indices,
module,
mod_info,
pipeline_options,
},
result_struct: Some(&stage_out_name),
};

// Finally, declare all the local variables that we need
//TODO: we can postpone this till the relevant expressions are emitted
for (local_handle, local) in fun.local_variables.iter() {
Expand All @@ -4031,7 +4049,7 @@ impl<W: Write> Writer<W> {
match local.init {
Some(value) => {
write!(self.out, " = ")?;
self.put_const_expression(value, module, mod_info)?;
self.put_expression(value, &context.expression, true)?;
}
None => {
write!(self.out, " = {{}}")?;
Expand All @@ -4040,22 +4058,6 @@ impl<W: Write> Writer<W> {
writeln!(self.out, ";")?;
}

let guarded_indices =
index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);

let context = StatementContext {
expression: ExpressionContext {
function: fun,
origin: FunctionOrigin::EntryPoint(ep_index as _),
info: fun_info,
policies: options.bounds_check_policies,
guarded_indices,
module,
mod_info,
pipeline_options,
},
result_struct: Some(&stage_out_name),
};
self.named_expressions.clear();
self.update_expressions_to_bake(fun, fun_info, &context.expression);
self.put_block(back::Level(1), &fun.body, &context)?;
Expand Down
55 changes: 36 additions & 19 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,39 @@ impl<'w> BlockContext<'w> {
self.temp_list.push(self.cached[component]);
}

let id = self.gen_id();
block.body.push(Instruction::composite_construct(
result_type_id,
id,
&self.temp_list,
));
id
if self.ir_function.expressions.is_const(expr_handle) {
let ty = self
.writer
.get_expression_lookup_type(&self.fun_info[expr_handle].ty);
self.writer.get_constant_composite(ty, &self.temp_list)
} else {
let id = self.gen_id();
block.body.push(Instruction::composite_construct(
result_type_id,
id,
&self.temp_list,
));
id
}
}
crate::Expression::Splat { size, value } => {
let value_id = self.cached[value];
let components = &[value_id; 4][..size as usize];

if self.ir_function.expressions.is_const(expr_handle) {
let ty = self
.writer
.get_expression_lookup_type(&self.fun_info[expr_handle].ty);
self.writer.get_constant_composite(ty, components)
} else {
let id = self.gen_id();
block.body.push(Instruction::composite_construct(
result_type_id,
id,
components,
));
id
}
}
crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
// See `is_intermediate`; we'll handle this later in
Expand Down Expand Up @@ -389,17 +415,6 @@ impl<'w> BlockContext<'w> {
crate::Expression::GlobalVariable(handle) => {
self.writer.global_variables[handle.index()].access_id
}
crate::Expression::Splat { size, value } => {
let value_id = self.cached[value];
let components = [value_id; 4];
let id = self.gen_id();
block.body.push(Instruction::composite_construct(
result_type_id,
id,
&components[..size as usize],
));
id
}
crate::Expression::Swizzle {
size,
vector,
Expand Down Expand Up @@ -1712,7 +1727,9 @@ impl<'w> BlockContext<'w> {
match *statement {
crate::Statement::Emit(ref range) => {
for handle in range.clone() {
self.cache_expression_value(handle, &mut block)?;
if !self.ir_function.expressions.is_const(handle) {
self.cache_expression_value(handle, &mut block)?;
}
}
}
crate::Statement::Block(ref block_statements) => {
Expand Down
74 changes: 42 additions & 32 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,37 +336,6 @@ impl Writer {
) -> Result<Word, Error> {
let mut function = Function::default();

for (handle, variable) in ir_function.local_variables.iter() {
let id = self.id_gen.next();

if self.flags.contains(WriterFlags::DEBUG) {
if let Some(ref name) = variable.name {
self.debugs.push(Instruction::name(id, name));
}
}

let init_word = variable
.init
.map(|constant| self.constant_ids[constant.index()]);
let pointer_type_id =
self.get_pointer_id(&ir_module.types, variable.ty, spirv::StorageClass::Function)?;
let instruction = Instruction::variable(
pointer_type_id,
id,
spirv::StorageClass::Function,
init_word.or_else(|| match ir_module.types[variable.ty].inner {
crate::TypeInner::RayQuery => None,
_ => {
let type_id = self.get_type_id(LookupType::Handle(variable.ty));
Some(self.write_constant_null(type_id))
}
}),
);
function
.variables
.insert(handle, LocalVariable { id, instruction });
}

let prelude_id = self.id_gen.next();
let mut prelude = Block::new(prelude_id);
let mut ep_context = EntryPointContext {
Expand Down Expand Up @@ -654,7 +623,48 @@ impl Writer {
// fill up the pre-emitted expressions
context.cached.reset(ir_function.expressions.len());
for (handle, expr) in ir_function.expressions.iter() {
if expr.needs_pre_emit() {
if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
|| ir_function.expressions.is_const(handle)
{
context.cache_expression_value(handle, &mut prelude)?;
}
}

for (handle, variable) in ir_function.local_variables.iter() {
let id = context.gen_id();

if context.writer.flags.contains(WriterFlags::DEBUG) {
if let Some(ref name) = variable.name {
context.writer.debugs.push(Instruction::name(id, name));
}
}

let init_word = variable.init.map(|constant| context.cached[constant]);
let pointer_type_id = context.writer.get_pointer_id(
&ir_module.types,
variable.ty,
spirv::StorageClass::Function,
)?;
let instruction = Instruction::variable(
pointer_type_id,
id,
spirv::StorageClass::Function,
init_word.or_else(|| match ir_module.types[variable.ty].inner {
crate::TypeInner::RayQuery => None,
_ => {
let type_id = context.get_type_id(LookupType::Handle(variable.ty));
Some(context.writer.write_constant_null(type_id))
}
}),
);
context
.function
.variables
.insert(handle, LocalVariable { id, instruction });
}

for (handle, expr) in ir_function.expressions.iter() {
if expr.needs_pre_emit() && matches!(*expr, crate::Expression::LocalVariable(_)) {
context.cache_expression_value(handle, &mut prelude)?;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ impl<W: Write> Writer<W> {

// Write the constant
// `write_constant` adds no trailing or leading space/newline
self.write_const_expression(module, init)?;
self.write_expr(module, init, func_ctx)?;
}

// Finish the local with `;` and add a newline (only for readability)
Expand Down
22 changes: 10 additions & 12 deletions src/front/glsl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -984,18 +984,16 @@ impl<'a> Context<'a> {
// pointer type which is required for dynamic indexing
if !constant_index {
if let Some((constant, ty)) = var.constant {
let local =
self.locals.append(
LocalVariable {
name: None,
ty,
init: Some(self.module.const_expressions.append(
Expression::Constant(constant),
Span::default(),
)),
},
Span::default(),
);
let init = self
.add_expression(Expression::Constant(constant), Span::default())?;
let local = self.locals.append(
LocalVariable {
name: None,
ty,
init: Some(init),
},
Span::default(),
);

self.add_expression(Expression::LocalVariable(local), Span::default())?
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let init_id = self.next()?;
let lconst = self.lookup_constant.lookup(init_id)?;
Some(
ctx.const_expressions
ctx.expressions
.append(crate::Expression::Constant(lconst.handle), span),
)
} else {
Expand Down
10 changes: 9 additions & 1 deletion src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1159,11 +1159,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
};

let (const_initializer, initializer) = {
match initializer {
Some(init) if ctx.naga_expressions.is_const(init) => (Some(init), None),
Some(init) => (None, Some(init)),
None => (None, None),
}
};

let var = ctx.variables.append(
crate::LocalVariable {
name: Some(v.name.name.to_string()),
ty,
init: None,
init: const_initializer,
},
stmt.span,
);
Expand Down
4 changes: 1 addition & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,9 +964,7 @@ pub struct LocalVariable {
pub name: Option<String>,
/// The type of this variable.
pub ty: Handle<Type>,
/// Initial value for this variable.
///
/// Expression handle lives in const_expressions
/// Initial value for this variable. Must be a const-expression.
pub init: Option<Handle<Expression>>,
}

Expand Down
14 changes: 14 additions & 0 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ pub enum ExprType {
// Math
// As

// TODO(teoxoy): consider accumulating this metadata instead of recursing through subexpressions
impl Arena<Expression> {
pub fn is_const(&self, handle: Handle<Expression>) -> bool {
match self[handle] {
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => true,
Expression::Compose { ref components, .. } => {
components.iter().all(|h| self.is_const(*h))
}
Expression::Splat { ref value, .. } => self.is_const(*value),
_ => false,
}
}
}

impl<'a, F: FnMut(&mut Arena<Expression>, Expression, Span) -> Handle<Expression>>
ConstantEvaluator<'a, F>
{
Expand Down
Loading

0 comments on commit 6f64c99

Please sign in to comment.