Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wgsl-in] consolidate type resolution logic in a few macros #2571

Merged
merged 3 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/front/wgsl/lower/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
[component] => {
let span = ctx.ast_expressions.get_span(component);
let component = self.expression(component, ctx.reborrow())?;
ctx.grow_types(component)?;
let ty = &ctx.typifier()[component];
let ty = super::resolve!(ctx, component);

ComponentsHandle::One {
component,
Expand All @@ -178,8 +177,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)
.collect();

ctx.grow_types(component)?;
let ty = &ctx.typifier()[component];
let ty = super::resolve!(ctx, component);

ComponentsHandle::Many {
components,
Expand Down
102 changes: 50 additions & 52 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,32 @@ use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span};

mod construction;

macro_rules! resolve {
($ctx:ident, $expr:expr) => {{
$ctx.grow_types($expr)?;
&$ctx.typifier()[$expr]
}};
}
pub(super) use resolve;

macro_rules! resolve_inner {
($ctx:ident, $expr:expr) => {{
$ctx.grow_types($expr)?;
$ctx.typifier()[$expr].inner_with(&$ctx.module.types)
}};
}

macro_rules! resolve_inner_x2 {
teoxoy marked this conversation as resolved.
Show resolved Hide resolved
($ctx:ident, $left:expr, $right:expr) => {{
$ctx.grow_types($left)?;
$ctx.grow_types($right)?;
(
$ctx.typifier()[$left].inner_with(&$ctx.module.types),
$ctx.typifier()[$right].inner_with(&$ctx.module.types),
)
}};
}

/// State for constructing a `crate::Module`.
pub struct GlobalContext<'source, 'temp, 'out> {
/// The `TranslationUnit`'s expressions arena.
Expand Down Expand Up @@ -460,15 +486,13 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {

/// Determine the type of `handle`, and add it to the module's arena.
///
/// If you just need a `TypeInner` for `handle`'s type, use
/// [`grow_types`] and [`resolved_inner`] instead. This function
/// If you just need a `TypeInner` for `handle`'s type, use the
/// [`resolve_inner!`] macro instead. This function
/// should only be used when the type of `handle` needs to appear
/// in the module's final `Arena<Type>`, for example, if you're
/// creating a [`LocalVariable`] whose type is inferred from its
/// initializer.
///
/// [`grow_types`]: Self::grow_types
/// [`resolved_inner`]: Self::resolved_inner
/// [`LocalVariable`]: crate::LocalVariable
fn register_type(
&mut self,
Expand Down Expand Up @@ -498,12 +522,11 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
/// return a shared reference to the resulting `TypeResolution`:
/// the shared reference would extend the mutable borrow, and you
/// wouldn't be able to use `self` for anything else. Instead, you
/// should call `grow_types` to cover the handles you need, and
/// then use `self.typifier[handle]` or
/// [`self.resolved_inner(handle)`] to get at their resolutions.
/// should use [`register_type`] or one of [`resolve!`],
/// [`resolve_inner!`] or [`resolve_inner_x2!`].
///
/// [`self.typifier`]: ExpressionContext::typifier
/// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner
/// [`register_type`]: Self::register_type
/// [`Typifier`]: Typifier
fn grow_types(
&mut self,
Expand Down Expand Up @@ -533,17 +556,12 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
Ok(self)
}

fn resolved_inner(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
self.typifier()[handle].inner_with(&self.module.types)
}

fn image_data(
&mut self,
image: Handle<crate::Expression>,
span: Span,
) -> Result<(crate::ImageClass, bool), Error<'source>> {
self.grow_types(image)?;
match *self.resolved_inner(image) {
match *resolve_inner!(self, image) {
crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)),
_ => Err(Error::BadTexture(span)),
}
Expand Down Expand Up @@ -584,9 +602,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo
) {
self.grow_types(*left)?.grow_types(*right)?;

match (self.resolved_inner(*left), self.resolved_inner(*right)) {
match resolve_inner_x2!(self, *left, *right) {
(&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => {
*right = self.append_expression(
crate::Expression::Splat {
Expand Down Expand Up @@ -1146,11 +1162,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let initializer = match v.init {
Some(init) => {
let initializer =
self.expression(init, ctx.as_expression(block, &mut emitter))?;
ctx.as_expression(block, &mut emitter)
.grow_types(initializer)?;
teoxoy marked this conversation as resolved.
Show resolved Hide resolved
Some(initializer)
Some(self.expression(init, ctx.as_expression(block, &mut emitter))?)
}
None => None,
};
Expand All @@ -1161,8 +1173,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let ty = match (explicit_ty, initializer) {
(Some(explicit), Some(initializer)) => {
let ctx = ctx.as_expression(block, &mut emitter);
let initializer_ty = ctx.resolved_inner(initializer);
let mut ctx = ctx.as_expression(block, &mut emitter);
let initializer_ty = resolve_inner!(ctx, initializer);
if !ctx.module.types[explicit]
.inner
.equivalent(initializer_ty, &ctx.module.types)
Expand Down Expand Up @@ -1266,9 +1278,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut ectx = ctx.as_expression(block, &mut emitter);
let selector = self.expression(selector, ectx.reborrow())?;

ectx.grow_types(selector)?;
let uint =
ectx.resolved_inner(selector).scalar_kind() == Some(crate::ScalarKind::Uint);
resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
block.extend(emitter.finish(ctx.naga_expressions));

let cases = cases
Expand Down Expand Up @@ -1407,8 +1418,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
self.expression_for_reference(value, ctx.as_expression(block, &mut emitter))?;
let mut ectx = ctx.as_expression(block, &mut emitter);

ectx.grow_types(reference.handle)?;
let (kind, width) = match *ectx.resolved_inner(reference.handle) {
let (kind, width) = match *resolve_inner!(ectx, reference.handle) {
crate::TypeInner::ValuePointer {
size: None,
kind,
Expand Down Expand Up @@ -1553,8 +1563,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// The pointer we dereference must be loaded.
let pointer = self.expression(expr, ctx.reborrow())?;

ctx.grow_types(pointer)?;
if ctx.resolved_inner(pointer).pointer_space().is_none() {
if resolve_inner!(ctx, pointer).pointer_space().is_none() {
return Err(Error::NotPointer(span));
}

Expand Down Expand Up @@ -1583,9 +1592,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let expr = self.expression_for_reference(base, ctx.reborrow())?;
let index = self.expression(index, ctx.reborrow())?;

ctx.grow_types(expr.handle)?;
let wgsl_pointer =
ctx.resolved_inner(expr.handle).pointer_space().is_some() && !expr.is_reference;
let wgsl_pointer = resolve_inner!(ctx, expr.handle).pointer_space().is_some()
&& !expr.is_reference;

if wgsl_pointer {
return Err(Error::Pointer(
Expand Down Expand Up @@ -1618,9 +1626,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
is_reference,
} = self.expression_for_reference(base, ctx.reborrow())?;

ctx.grow_types(handle)?;
let temp_inner;
let (composite, wgsl_pointer) = match *ctx.resolved_inner(handle) {
let (composite, wgsl_pointer) = match *resolve_inner!(ctx, handle) {
crate::TypeInner::Pointer { base, .. } => {
(&ctx.module.types[base].inner, !is_reference)
}
Expand Down Expand Up @@ -1707,8 +1714,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
crate::TypeInner::Scalar { kind, .. } => kind,
crate::TypeInner::Vector { kind, .. } => kind,
_ => {
ctx.grow_types(expr)?;
let ty = &ctx.typifier()[expr];
let ty = resolve!(ctx, expr);
return Err(Error::BadTypeCast {
from_type: ctx.format_type_resolution(ty),
span: ty_span,
Expand Down Expand Up @@ -1814,9 +1820,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
fun,
crate::RelationalFunction::All | crate::RelationalFunction::Any
) && {
ctx.grow_types(argument)?;
matches!(
ctx.resolved_inner(argument),
resolve_inner!(ctx, argument),
&crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
..
Expand Down Expand Up @@ -1859,8 +1864,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
args.finish()?;

if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp {
ctx.grow_types(arg)?;
if let Some((size, width)) = match *ctx.resolved_inner(arg) {
if let Some((size, width)) = match *resolve_inner!(ctx, arg) {
crate::TypeInner::Scalar { width, .. } => Some((None, width)),
crate::TypeInner::Vector { size, width, .. } => {
Some((Some(size), width))
Expand Down Expand Up @@ -2005,11 +2009,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let value = args.next()?;
let value_span = ctx.ast_expressions.get_span(value);
let value = self.expression(value, ctx.reborrow())?;
ctx.grow_types(value)?;

args.finish()?;

let expression = match *ctx.resolved_inner(value) {
let expression = match *resolve_inner!(ctx, value) {
crate::TypeInner::Scalar { kind, width } => {
crate::Expression::AtomicResult {
ty: ctx.module.generate_predeclared_type(
Expand Down Expand Up @@ -2061,8 +2064,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
args.finish()?;

let pointer = self.expression(expr, ctx.reborrow())?;
ctx.grow_types(pointer)?;
let result_ty = match *ctx.resolved_inner(pointer) {
let result_ty = match *resolve_inner!(ctx, pointer) {
crate::TypeInner::Pointer {
base,
space: crate::AddressSpace::WorkGroup,
Expand All @@ -2083,7 +2085,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
span,
);

ctx.grow_types(pointer)?;
return Ok(Some(result));
}
"textureStore" => {
Expand Down Expand Up @@ -2280,8 +2281,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let span = ctx.ast_expressions.get_span(expr);
let pointer = self.expression(expr, ctx.reborrow())?;

ctx.grow_types(pointer)?;
match *ctx.resolved_inner(pointer) {
match *resolve_inner!(ctx, pointer) {
crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
crate::TypeInner::Atomic { .. } => Ok(pointer),
ref other => {
Expand Down Expand Up @@ -2361,8 +2361,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let lowered_image_or_component =
self.expression(image_or_component, ctx.reborrow())?;

ctx.grow_types(lowered_image_or_component)?;
match *ctx.resolved_inner(lowered_image_or_component) {
match *resolve_inner!(ctx, lowered_image_or_component) {
crate::TypeInner::Image {
class: crate::ImageClass::Depth { .. },
..
Expand Down Expand Up @@ -2683,8 +2682,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let span = ctx.ast_expressions.get_span(expr);
let pointer = self.expression(expr, ctx.reborrow())?;

ctx.grow_types(pointer)?;
match *ctx.resolved_inner(pointer) {
match *resolve_inner!(ctx, pointer) {
crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
crate::TypeInner::RayQuery => Ok(pointer),
ref other => {
Expand Down