Skip to content

Commit

Permalink
[naga msl-out] Defeat the MSL compiler's infinite loop analysis.
Browse files Browse the repository at this point in the history
See the comments in the code for details.

This patch emits the definition of the macro only when the first loop
is encountered. This does make that first loop's code look a bit odd:
it would be more natural to define the macro at the top of the
file. (See the modified files in `naga/tests/out/msl`.)

Rejected alternatives:

- We could emit the macro definition unconditionally at the top of the
  file. But this changes every MSL snapshot output file, whereas only
  eight of them actually contain loops.

- We could have the validator flag modules that contain loops. But the
  changes end up being not small, and spread across the validator, so
  this seems disproportionate. If we had other consumers of this
  information, it might make sense.

- We could change the MSL backend to allow text to be generated out of
  order, so that we can decide whether to define the macro after we've
  generated all the function bodies. But at the moment this seems like
  unnecessary complexity, although it might be worth doing in the
  future if we had additional uses for it - say, to conditionally emit
  helper function definitions.

Fixes #4972.
  • Loading branch information
jimblandy authored and ErichDonGubler committed Sep 18, 2024
1 parent c3ab12a commit 3fda684
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 19 deletions.
139 changes: 137 additions & 2 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,11 @@ pub struct Writer<W> {
/// Set of (struct type, struct field index) denoting which fields require
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,

/// Name of the loop reachability macro.
///
/// See `emit_loop_reachable_macro` for details.
loop_reachable_macro_name: String,
}

impl crate::Scalar {
Expand Down Expand Up @@ -665,6 +670,7 @@ impl<W: Write> Writer<W> {
#[cfg(test)]
put_block_stack_pointers: Default::default(),
struct_member_pads: FastHashSet::default(),
loop_reachable_macro_name: String::default(),
}
}

Expand All @@ -675,6 +681,125 @@ impl<W: Write> Writer<W> {
self.out
}

/// Define a macro to invoke before loops, to defeat MSL infinite loop
/// reasoning.
///
/// If we haven't done so already, emit the definition of a preprocessor
/// macro to be invoked before each loop in the generated MSL, to ensure
/// that the MSL compiler's optimizations do not remove bounds checks.
///
/// Only the first call to this function for a given module actually causes
/// the macro definition to be written. Subsequent loops can simply use the
/// prior macro definition, since macros aren't block-scoped.
///
/// # What is this trying to solve?
///
/// In Metal Shading Language, an infinite loop has undefined behavior.
/// (This rule is inherited from C++14.) This means that, if the MSL
/// compiler determines that a given loop will never exit, it may assume
/// that it is never reached. It may thus assume that any conditions
/// sufficient to cause the loop to be reached must be false. Like many
/// optimizing compilers, MSL uses this kind of analysis to establish limits
/// on the range of values variables involved in those conditions might
/// hold.
///
/// For example, suppose the MSL compiler sees the code:
///
/// ```ignore
/// if (i >= 10) {
/// while (true) { }
/// }
/// ```
///
/// It will recognize that the `while` loop will never terminate, conclude
/// that it must be unreachable, and thus infer that, if this code is
/// reached, then `i < 10` at that point.
///
/// Now suppose that, at some point where `i` has the same value as above,
/// the compiler sees the code:
///
/// ```ignore
/// if (i < 10) {
/// a[i] = 1;
/// }
/// ```
///
/// Because the compiler is confident that `i < 10`, it will make the
/// assignment to `a[i]` unconditional, rewriting this code as, simply:
///
/// ```ignore
/// a[i] = 1;
/// ```
///
/// If that `if` condition was injected by Naga to implement a bounds check,
/// the MSL compiler's optimizations could allow out-of-bounds array
/// accesses to occur.
///
/// Naga cannot feasibly anticipate whether the MSL compiler will determine
/// that a loop is infinite, so an attacker could craft a Naga module
/// containing an infinite loop protected by conditions that cause the Metal
/// compiler to remove bounds checks that Naga injected elsewhere in the
/// function.
///
/// This rewrite could occur even if the conditional assignment appears
/// *before* the `while` loop, as long as `i < 10` by the time the loop is
/// reached. This would allow the attacker to save the results of
/// unauthorized reads somewhere accessible before entering the infinite
/// loop. But even worse, the MSL compiler has been observed to simply
/// delete the infinite loop entirely, so that even code dominated by the
/// loop becomes reachable. This would make the attack even more flexible,
/// since shaders that would appear to never terminate would actually exit
/// nicely, after having stolen data from elsewhere in the GPU address
/// space.
///
/// Ideally, Naga would prevent UB entirely via some means that persuades
/// the MSL compiler that no loop Naga generates is infinite. One approach
/// would be to add inline assembly to each loop that is annotated as
/// potentially branching out of the loop, but which in fact generates no
/// instructions. Unfortunately, inline assembly is not handled correctly by
/// some Metal device drivers. Further experimentation hasn't produced a
/// satisfactory approach.
///
/// Instead, we accept that the MSL compiler may determine that some loops
/// are infinite, and focus instead on preventing the range analysis from
/// being affected. We transform *every* loop into something like this:
///
/// ```ignore
/// if (volatile bool unpredictable = true; unpredictable)
/// while (true) { }
/// ```
///
/// Since the `volatile` qualifier prevents the compiler from assuming that
/// the `if` condition is true, it cannot be sure the infinite loop is
/// reached, and thus it cannot assume the entire structure is unreachable.
/// This prevents the range analysis impact described above.
///
/// Unfortunately, what makes this a kludge, not a hack, is that this
/// solution leaves the GPU executing a pointless conditional branch, at
/// runtime, before each loop. There's no part of the system that has a
/// global enough view to be sure that `unpredictable` is true, and remove
/// it from the code.
///
/// To make our output a bit more legible, we pull the condition out into a
/// preprocessor macro defined at the top of the module.
fn emit_loop_reachable_macro(&mut self) -> BackendResult {
if !self.loop_reachable_macro_name.is_empty() {
return Ok(());
}

self.loop_reachable_macro_name = self.namer.call("LOOP_IS_REACHABLE");
let loop_reachable_volatile_name = self.namer.call("unpredictable_jump_over_loop");
writeln!(
self.out,
"#define {} if (volatile bool {} = true; {})",
self.loop_reachable_macro_name,
loop_reachable_volatile_name,
loop_reachable_volatile_name,
)?;

Ok(())
}

fn put_call_parameters(
&mut self,
parameters: impl Iterator<Item = Handle<crate::Expression>>,
Expand Down Expand Up @@ -2924,10 +3049,15 @@ impl<W: Write> Writer<W> {
ref continuing,
break_if,
} => {
self.emit_loop_reachable_macro()?;
if !continuing.is_empty() || break_if.is_some() {
let gate_name = self.namer.call("loop_init");
writeln!(self.out, "{level}bool {gate_name} = true;")?;
writeln!(self.out, "{level}while(true) {{")?;
writeln!(
self.out,
"{level}{} while(true) {{",
self.loop_reachable_macro_name,
)?;
let lif = level.next();
let lcontinuing = lif.next();
writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
Expand All @@ -2942,7 +3072,11 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "{lif}}}")?;
writeln!(self.out, "{lif}{gate_name} = false;")?;
} else {
writeln!(self.out, "{level}while(true) {{")?;
writeln!(
self.out,
"{level}{} while(true) {{",
self.loop_reachable_macro_name,
)?;
}
self.put_block(level.next(), body, context)?;
writeln!(self.out, "{level}}}")?;
Expand Down Expand Up @@ -3379,6 +3513,7 @@ impl<W: Write> Writer<W> {
&[CLAMPED_LOD_LOAD_PREFIX],
&mut self.names,
);
self.loop_reachable_macro_name.clear();
self.struct_member_pads.clear();

writeln!(
Expand Down
3 changes: 2 additions & 1 deletion naga/tests/out/msl/boids.msl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ kernel void main_(
vPos = _e8;
metal::float2 _e14 = particlesSrc.particles[index].vel;
vVel = _e14;
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
while(true) {
LOOP_IS_REACHABLE while(true) {
if (!loop_init) {
uint _e91 = i;
i = _e91 + 1u;
Expand Down
9 changes: 5 additions & 4 deletions naga/tests/out/msl/break-if.msl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ using metal::uint;

void breakIfEmpty(
) {
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
while(true) {
LOOP_IS_REACHABLE while(true) {
if (!loop_init) {
if (true) {
break;
Expand All @@ -25,7 +26,7 @@ void breakIfEmptyBody(
bool b = {};
bool c = {};
bool loop_init_1 = true;
while(true) {
LOOP_IS_REACHABLE while(true) {
if (!loop_init_1) {
b = a;
bool _e2 = b;
Expand All @@ -46,7 +47,7 @@ void breakIf(
bool d = {};
bool e = {};
bool loop_init_2 = true;
while(true) {
LOOP_IS_REACHABLE while(true) {
if (!loop_init_2) {
bool _e5 = e;
if (a_1 == e) {
Expand All @@ -65,7 +66,7 @@ void breakIfSeparateVariable(
) {
uint counter = 0u;
bool loop_init_3 = true;
while(true) {
LOOP_IS_REACHABLE while(true) {
if (!loop_init_3) {
uint _e5 = counter;
if (counter == 5u) {
Expand Down
3 changes: 2 additions & 1 deletion naga/tests/out/msl/collatz.msl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ uint collatz_iterations(
uint n = {};
uint i = 0u;
n = n_base;
while(true) {
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
LOOP_IS_REACHABLE while(true) {
uint _e4 = n;
if (_e4 > 1u) {
} else {
Expand Down
13 changes: 7 additions & 6 deletions naga/tests/out/msl/control-flow.msl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ void switch_case_break(
void loop_switch_continue(
int x
) {
while(true) {
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
LOOP_IS_REACHABLE while(true) {
switch(x) {
case 1: {
continue;
Expand All @@ -49,7 +50,7 @@ void loop_switch_continue_nesting(
int y,
int z
) {
while(true) {
LOOP_IS_REACHABLE while(true) {
switch(x_1) {
case 1: {
continue;
Expand All @@ -60,7 +61,7 @@ void loop_switch_continue_nesting(
continue;
}
default: {
while(true) {
LOOP_IS_REACHABLE while(true) {
switch(z) {
case 1: {
continue;
Expand All @@ -85,7 +86,7 @@ void loop_switch_continue_nesting(
}
}
}
while(true) {
LOOP_IS_REACHABLE while(true) {
switch(y) {
case 1:
default: {
Expand All @@ -108,7 +109,7 @@ void loop_switch_omit_continue_variable_checks(
int w
) {
int pos_1 = 0;
while(true) {
LOOP_IS_REACHABLE while(true) {
switch(x_2) {
case 1: {
pos_1 = 1;
Expand All @@ -119,7 +120,7 @@ void loop_switch_omit_continue_variable_checks(
}
}
}
while(true) {
LOOP_IS_REACHABLE while(true) {
switch(x_2) {
case 1: {
break;
Expand Down
3 changes: 2 additions & 1 deletion naga/tests/out/msl/do-while.msl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ using metal::uint;
void fb1_(
thread bool& cond
) {
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
while(true) {
LOOP_IS_REACHABLE while(true) {
if (!loop_init) {
bool _e1 = cond;
if (!(cond)) {
Expand Down
3 changes: 2 additions & 1 deletion naga/tests/out/msl/overrides-ray-query.msl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ kernel void main_(
rq.intersector.force_opacity((desc.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (desc.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
rq.intersector.accept_any_intersection((desc.flags & 4) != 0);
rq.intersection = rq.intersector.intersect(metal::raytracing::ray(desc.origin, desc.dir, desc.tmin, desc.tmax), acc_struct, desc.cull_mask); rq.ready = true;
while(true) {
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
LOOP_IS_REACHABLE while(true) {
bool _e31 = rq.ready;
rq.ready = false;
if (_e31) {
Expand Down
3 changes: 2 additions & 1 deletion naga/tests/out/msl/ray-query.msl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ RayIntersection query_loop(
rq.intersector.force_opacity((_e8.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e8.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
rq.intersector.accept_any_intersection((_e8.flags & 4) != 0);
rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e8.origin, _e8.dir, _e8.tmin, _e8.tmax), acs, _e8.cull_mask); rq.ready = true;
while(true) {
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
LOOP_IS_REACHABLE while(true) {
bool _e9 = rq.ready;
rq.ready = false;
if (_e9) {
Expand Down
5 changes: 3 additions & 2 deletions naga/tests/out/msl/shadow.msl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ fragment fs_mainOutput fs_main(
metal::float3 color = c_ambient;
uint i = 0u;
metal::float3 normal_1 = metal::normalize(in.world_normal);
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
while(true) {
LOOP_IS_REACHABLE while(true) {
if (!loop_init) {
uint _e40 = i;
i = _e40 + 1u;
Expand Down Expand Up @@ -151,7 +152,7 @@ fragment fs_main_without_storageOutput fs_main_without_storage(
uint i_1 = 0u;
metal::float3 normal_2 = metal::normalize(in_1.world_normal);
bool loop_init_1 = true;
while(true) {
LOOP_IS_REACHABLE while(true) {
if (!loop_init_1) {
uint _e40 = i_1;
i_1 = _e40 + 1u;
Expand Down

0 comments on commit 3fda684

Please sign in to comment.