Skip to content

Commit

Permalink
[wgsl-in] consolidate type resolution logic in a few macros (#2571)
Browse files Browse the repository at this point in the history
* [wgsl-in] consolidate type resolution logic in a few macros

* rename + docs

* reorder macros (avoids doc linking not working)
  • Loading branch information
teoxoy authored Oct 19, 2023
1 parent 29ca531 commit 19209b6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 56 deletions.
6 changes: 2 additions & 4 deletions src/front/wgsl/lower/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
[component] => {
let span = ctx.ast_expressions.get_span(component);
let component = self.expression(component, ctx.reborrow())?;
ctx.grow_types(component)?;
let ty = &ctx.typifier()[component];
let ty = super::resolve!(ctx, component);

ComponentsHandle::One {
component,
Expand All @@ -178,8 +177,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)
.collect();

ctx.grow_types(component)?;
let ty = &ctx.typifier()[component];
let ty = super::resolve!(ctx, component);

ComponentsHandle::Many {
components,
Expand Down
129 changes: 77 additions & 52 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,59 @@ use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span};

mod construction;

/// Resolves the inner type of a given expression.
///
/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`].
///
/// Returns a &[`crate::TypeInner`].
///
/// Ideally, we would simply have a function that takes a `&mut ExpressionContext`
/// and returns a `&TypeResolution`. Unfortunately, this leads the borrow checker
/// to conclude that the mutable borrow lasts for as long as we are using the
/// `&TypeResolution`, so we can't use the `ExpressionContext` for anything else -
/// like, say, resolving another operand's type. Using a macro that expands to
/// two separate calls, only the first of which needs a `&mut`,
/// lets the borrow checker see that the mutable borrow is over.
macro_rules! resolve_inner {
($ctx:ident, $expr:expr) => {{
$ctx.grow_types($expr)?;
$ctx.typifier()[$expr].inner_with(&$ctx.module.types)
}};
}

/// Resolves the inner types of two given expressions.
///
/// Expects a &mut [`ExpressionContext`] and two [`Handle<Expression>`]s.
///
/// Returns a tuple containing two &[`crate::TypeInner`].
///
/// See the documentation of [`resolve_inner!`] for why this macro is necessary.
macro_rules! resolve_inner_binary {
($ctx:ident, $left:expr, $right:expr) => {{
$ctx.grow_types($left)?;
$ctx.grow_types($right)?;
(
$ctx.typifier()[$left].inner_with(&$ctx.module.types),
$ctx.typifier()[$right].inner_with(&$ctx.module.types),
)
}};
}

/// Resolves the type of a given expression.
///
/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`].
///
/// Returns a &[`TypeResolution`].
///
/// See the documentation of [`resolve_inner!`] for why this macro is necessary.
macro_rules! resolve {
($ctx:ident, $expr:expr) => {{
$ctx.grow_types($expr)?;
&$ctx.typifier()[$expr]
}};
}
pub(super) use resolve;

/// State for constructing a `crate::Module`.
pub struct GlobalContext<'source, 'temp, 'out> {
/// The `TranslationUnit`'s expressions arena.
Expand Down Expand Up @@ -460,15 +513,13 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {

/// Determine the type of `handle`, and add it to the module's arena.
///
/// If you just need a `TypeInner` for `handle`'s type, use
/// [`grow_types`] and [`resolved_inner`] instead. This function
/// If you just need a `TypeInner` for `handle`'s type, use the
/// [`resolve_inner!`] macro instead. This function
/// should only be used when the type of `handle` needs to appear
/// in the module's final `Arena<Type>`, for example, if you're
/// creating a [`LocalVariable`] whose type is inferred from its
/// initializer.
///
/// [`grow_types`]: Self::grow_types
/// [`resolved_inner`]: Self::resolved_inner
/// [`LocalVariable`]: crate::LocalVariable
fn register_type(
&mut self,
Expand Down Expand Up @@ -498,12 +549,11 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
/// return a shared reference to the resulting `TypeResolution`:
/// the shared reference would extend the mutable borrow, and you
/// wouldn't be able to use `self` for anything else. Instead, you
/// should call `grow_types` to cover the handles you need, and
/// then use `self.typifier[handle]` or
/// [`self.resolved_inner(handle)`] to get at their resolutions.
/// should use [`register_type`] or one of [`resolve!`],
/// [`resolve_inner!`] or [`resolve_inner_binary!`].
///
/// [`self.typifier`]: ExpressionContext::typifier
/// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner
/// [`register_type`]: Self::register_type
/// [`Typifier`]: Typifier
fn grow_types(
&mut self,
Expand Down Expand Up @@ -533,17 +583,12 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
Ok(self)
}

fn resolved_inner(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
self.typifier()[handle].inner_with(&self.module.types)
}

fn image_data(
&mut self,
image: Handle<crate::Expression>,
span: Span,
) -> Result<(crate::ImageClass, bool), Error<'source>> {
self.grow_types(image)?;
match *self.resolved_inner(image) {
match *resolve_inner!(self, image) {
crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)),
_ => Err(Error::BadTexture(span)),
}
Expand Down Expand Up @@ -584,9 +629,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo
) {
self.grow_types(*left)?.grow_types(*right)?;

match (self.resolved_inner(*left), self.resolved_inner(*right)) {
match resolve_inner_binary!(self, *left, *right) {
(&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => {
*right = self.append_expression(
crate::Expression::Splat {
Expand Down Expand Up @@ -1146,11 +1189,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let initializer = match v.init {
Some(init) => {
let initializer =
self.expression(init, ctx.as_expression(block, &mut emitter))?;
ctx.as_expression(block, &mut emitter)
.grow_types(initializer)?;
Some(initializer)
Some(self.expression(init, ctx.as_expression(block, &mut emitter))?)
}
None => None,
};
Expand All @@ -1161,8 +1200,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let ty = match (explicit_ty, initializer) {
(Some(explicit), Some(initializer)) => {
let ctx = ctx.as_expression(block, &mut emitter);
let initializer_ty = ctx.resolved_inner(initializer);
let mut ctx = ctx.as_expression(block, &mut emitter);
let initializer_ty = resolve_inner!(ctx, initializer);
if !ctx.module.types[explicit]
.inner
.equivalent(initializer_ty, &ctx.module.types)
Expand Down Expand Up @@ -1266,9 +1305,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut ectx = ctx.as_expression(block, &mut emitter);
let selector = self.expression(selector, ectx.reborrow())?;

ectx.grow_types(selector)?;
let uint =
ectx.resolved_inner(selector).scalar_kind() == Some(crate::ScalarKind::Uint);
resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
block.extend(emitter.finish(ctx.naga_expressions));

let cases = cases
Expand Down Expand Up @@ -1407,8 +1445,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
self.expression_for_reference(value, ctx.as_expression(block, &mut emitter))?;
let mut ectx = ctx.as_expression(block, &mut emitter);

ectx.grow_types(reference.handle)?;
let (kind, width) = match *ectx.resolved_inner(reference.handle) {
let (kind, width) = match *resolve_inner!(ectx, reference.handle) {
crate::TypeInner::ValuePointer {
size: None,
kind,
Expand Down Expand Up @@ -1553,8 +1590,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// The pointer we dereference must be loaded.
let pointer = self.expression(expr, ctx.reborrow())?;

ctx.grow_types(pointer)?;
if ctx.resolved_inner(pointer).pointer_space().is_none() {
if resolve_inner!(ctx, pointer).pointer_space().is_none() {
return Err(Error::NotPointer(span));
}

Expand Down Expand Up @@ -1583,9 +1619,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let expr = self.expression_for_reference(base, ctx.reborrow())?;
let index = self.expression(index, ctx.reborrow())?;

ctx.grow_types(expr.handle)?;
let wgsl_pointer =
ctx.resolved_inner(expr.handle).pointer_space().is_some() && !expr.is_reference;
let wgsl_pointer = resolve_inner!(ctx, expr.handle).pointer_space().is_some()
&& !expr.is_reference;

if wgsl_pointer {
return Err(Error::Pointer(
Expand Down Expand Up @@ -1618,9 +1653,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
is_reference,
} = self.expression_for_reference(base, ctx.reborrow())?;

ctx.grow_types(handle)?;
let temp_inner;
let (composite, wgsl_pointer) = match *ctx.resolved_inner(handle) {
let (composite, wgsl_pointer) = match *resolve_inner!(ctx, handle) {
crate::TypeInner::Pointer { base, .. } => {
(&ctx.module.types[base].inner, !is_reference)
}
Expand Down Expand Up @@ -1707,8 +1741,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
crate::TypeInner::Scalar { kind, .. } => kind,
crate::TypeInner::Vector { kind, .. } => kind,
_ => {
ctx.grow_types(expr)?;
let ty = &ctx.typifier()[expr];
let ty = resolve!(ctx, expr);
return Err(Error::BadTypeCast {
from_type: ctx.format_type_resolution(ty),
span: ty_span,
Expand Down Expand Up @@ -1814,9 +1847,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
fun,
crate::RelationalFunction::All | crate::RelationalFunction::Any
) && {
ctx.grow_types(argument)?;
matches!(
ctx.resolved_inner(argument),
resolve_inner!(ctx, argument),
&crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
..
Expand Down Expand Up @@ -1859,8 +1891,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
args.finish()?;

if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp {
ctx.grow_types(arg)?;
if let Some((size, width)) = match *ctx.resolved_inner(arg) {
if let Some((size, width)) = match *resolve_inner!(ctx, arg) {
crate::TypeInner::Scalar { width, .. } => Some((None, width)),
crate::TypeInner::Vector { size, width, .. } => {
Some((Some(size), width))
Expand Down Expand Up @@ -2005,11 +2036,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let value = args.next()?;
let value_span = ctx.ast_expressions.get_span(value);
let value = self.expression(value, ctx.reborrow())?;
ctx.grow_types(value)?;

args.finish()?;

let expression = match *ctx.resolved_inner(value) {
let expression = match *resolve_inner!(ctx, value) {
crate::TypeInner::Scalar { kind, width } => {
crate::Expression::AtomicResult {
ty: ctx.module.generate_predeclared_type(
Expand Down Expand Up @@ -2061,8 +2091,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
args.finish()?;

let pointer = self.expression(expr, ctx.reborrow())?;
ctx.grow_types(pointer)?;
let result_ty = match *ctx.resolved_inner(pointer) {
let result_ty = match *resolve_inner!(ctx, pointer) {
crate::TypeInner::Pointer {
base,
space: crate::AddressSpace::WorkGroup,
Expand All @@ -2083,7 +2112,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
span,
);

ctx.grow_types(pointer)?;
return Ok(Some(result));
}
"textureStore" => {
Expand Down Expand Up @@ -2280,8 +2308,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let span = ctx.ast_expressions.get_span(expr);
let pointer = self.expression(expr, ctx.reborrow())?;

ctx.grow_types(pointer)?;
match *ctx.resolved_inner(pointer) {
match *resolve_inner!(ctx, pointer) {
crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
crate::TypeInner::Atomic { .. } => Ok(pointer),
ref other => {
Expand Down Expand Up @@ -2361,8 +2388,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let lowered_image_or_component =
self.expression(image_or_component, ctx.reborrow())?;

ctx.grow_types(lowered_image_or_component)?;
match *ctx.resolved_inner(lowered_image_or_component) {
match *resolve_inner!(ctx, lowered_image_or_component) {
crate::TypeInner::Image {
class: crate::ImageClass::Depth { .. },
..
Expand Down Expand Up @@ -2683,8 +2709,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let span = ctx.ast_expressions.get_span(expr);
let pointer = self.expression(expr, ctx.reborrow())?;

ctx.grow_types(pointer)?;
match *ctx.resolved_inner(pointer) {
match *resolve_inner!(ctx, pointer) {
crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
crate::TypeInner::RayQuery => Ok(pointer),
ref other => {
Expand Down

0 comments on commit 19209b6

Please sign in to comment.