Skip to content

Commit

Permalink
disallow ptr to workgroup fn arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy authored and jimblandy committed Oct 17, 2023
1 parent ea83f62 commit 6854b0a
Show file tree
Hide file tree
Showing 12 changed files with 425 additions and 450 deletions.
7 changes: 1 addition & 6 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1001,12 +1001,7 @@ impl super::Validator {
#[cfg(feature = "validate")]
for (index, argument) in fun.arguments.iter().enumerate() {
match module.types[argument.ty].inner.pointer_space() {
Some(
crate::AddressSpace::Private
| crate::AddressSpace::Function
| crate::AddressSpace::WorkGroup,
)
| None => {}
Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
Some(other) => {
return Err(FunctionError::InvalidArgumentPointerSpace {
index,
Expand Down
14 changes: 8 additions & 6 deletions src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,14 @@ fn check_member_layout(
/// `TypeFlags::empty()`.
///
/// Pointers passed as arguments to user-defined functions must be in the
/// `Function`, `Private`, or `Workgroup` storage space.
/// `Function` or `Private` address space.
const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags {
use crate::AddressSpace as As;
match space {
As::Function | As::Private | As::WorkGroup => TypeFlags::ARGUMENT,
As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant => TypeFlags::empty(),
As::Function | As::Private => TypeFlags::ARGUMENT,
As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => {
TypeFlags::empty()
}
}
}

Expand Down Expand Up @@ -316,7 +318,7 @@ impl super::Validator {
return Err(TypeError::InvalidPointerBase(base));
}

// Runtime-sized values can only live in the `Storage` storage
// Runtime-sized values can only live in the `Storage` address
// space, so it's useless to have a pointer to such a type in
// any other space.
//
Expand All @@ -336,7 +338,7 @@ impl super::Validator {
}
}

// `Validator::validate_function` actually checks the storage
// `Validator::validate_function` actually checks the address
// space of pointer arguments explicitly before checking the
// `ARGUMENT` flag, to give better error messages. But it seems
// best to set `ARGUMENT` accurately anyway.
Expand Down Expand Up @@ -364,7 +366,7 @@ impl super::Validator {
// `InvalidPointerBase` or `InvalidPointerToUnsized`.
self.check_width(kind, width)?;

// `Validator::validate_function` actually checks the storage
// `Validator::validate_function` actually checks the address
// space of pointer arguments explicitly before checking the
// `ARGUMENT` flag, to give better error messages. But it seems
// best to set `ARGUMENT` accurately anyway.
Expand Down
9 changes: 4 additions & 5 deletions tests/in/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ fn foo_frag() -> @location(0) vec4<f32> {
return vec4<f32>(0.0);
}

var<workgroup> val: u32;

fn assign_through_ptr_fn(p: ptr<workgroup, u32>) {
fn assign_through_ptr_fn(p: ptr<function, u32>) {
*p = 42u;
}

Expand All @@ -163,8 +161,9 @@ fn assign_array_through_ptr_fn(foo: ptr<function, array<vec4<f32>, 2>>) {

@compute @workgroup_size(1)
fn assign_through_ptr() {
var arr = array<vec4<f32>, 2>(vec4(6.0), vec4(7.0));

var val = 33u;
assign_through_ptr_fn(&val);

var arr = array<vec4<f32>, 2>(vec4(6.0), vec4(7.0));
assign_array_through_ptr_fn(&arr);
}
47 changes: 25 additions & 22 deletions tests/out/analysis/access.info.ron
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
("READ"),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -1144,7 +1143,6 @@
(""),
(""),
("READ"),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2414,7 +2412,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2454,7 +2451,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2503,7 +2499,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2546,7 +2541,6 @@
(""),
(""),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -2638,7 +2632,6 @@
("READ"),
("READ"),
("READ"),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -3302,7 +3295,6 @@
(""),
("WRITE"),
(""),
(""),
],
expressions: [
(
Expand Down Expand Up @@ -3736,9 +3728,32 @@
(""),
(""),
(""),
("READ"),
],
expressions: [
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar(
kind: Uint,
width: 4,
)),
),
(
uniformity: (
non_uniform_result: Some(2),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 1,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
Expand Down Expand Up @@ -3800,7 +3815,7 @@
),
(
uniformity: (
non_uniform_result: Some(6),
non_uniform_result: Some(8),
requirements: (""),
),
ref_count: 1,
Expand All @@ -3810,18 +3825,6 @@
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: Some(6),
ty: Value(Pointer(
base: 1,
space: WorkGroup,
)),
),
],
sampling: [],
dual_source_blending: false,
Expand Down
8 changes: 1 addition & 7 deletions tests/out/glsl/access.assign_through_ptr.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ struct Baz {
struct MatCx2InArray {
mat4x2 am[2];
};
shared uint val;


float read_from_private(inout float foo_1) {
float _e1 = foo_1;
Expand All @@ -42,11 +40,7 @@ void assign_array_through_ptr_fn(inout vec4 foo_2[2]) {
}

void main() {
if (gl_LocalInvocationID == uvec3(0u)) {
val = 0u;
}
memoryBarrierShared();
barrier();
uint val = 33u;
vec4 arr[2] = vec4[2](vec4(6.0), vec4(7.0));
assign_through_ptr_fn(val);
assign_array_through_ptr_fn(arr);
Expand Down
8 changes: 2 additions & 6 deletions tests/out/hlsl/access.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ RWByteAddressBuffer bar : register(u0);
cbuffer baz : register(b1) { Baz baz; }
RWByteAddressBuffer qux : register(u2);
cbuffer nested_mat_cx2_ : register(b3) { MatCx2InArray nested_mat_cx2_; }
groupshared uint val;

Baz ConstructBaz(float3x2 arg0) {
Baz ret = (Baz)0;
Expand Down Expand Up @@ -288,12 +287,9 @@ float4 foo_frag() : SV_Target0
}

[numthreads(1, 1, 1)]
void assign_through_ptr(uint3 __local_invocation_id : SV_GroupThreadID)
void assign_through_ptr()
{
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
val = (uint)0;
}
GroupMemoryBarrierWithGroupSync();
uint val = 33u;
float4 arr[2] = Constructarray2_float4_((6.0).xxxx, (7.0).xxxx);

assign_through_ptr_fn(val);
Expand Down
47 changes: 23 additions & 24 deletions tests/out/ir/access.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@
name: None,
inner: Pointer(
base: 1,
space: WorkGroup,
space: Function,
),
),
(
Expand Down Expand Up @@ -356,13 +356,6 @@
ty: 20,
init: None,
),
(
name: Some("val"),
space: WorkGroup,
binding: None,
ty: 1,
init: None,
),
],
const_expressions: [
Literal(U32(0)),
Expand Down Expand Up @@ -2137,54 +2130,60 @@
arguments: [],
result: None,
local_variables: [
(
name: Some("val"),
ty: 1,
init: Some(1),
),
(
name: Some("arr"),
ty: 28,
init: Some(5),
init: Some(7),
),
],
expressions: [
Literal(U32(33)),
LocalVariable(1),
Literal(F32(6.0)),
Splat(
size: Quad,
value: 1,
value: 3,
),
Literal(F32(7.0)),
Splat(
size: Quad,
value: 3,
value: 5,
),
Compose(
ty: 28,
components: [
2,
4,
6,
],
),
LocalVariable(1),
GlobalVariable(6),
LocalVariable(2),
],
named_expressions: {},
body: [
Emit((
start: 1,
end: 2,
)),
Emit((
start: 3,
end: 5,
)),
Call(
function: 5,
arguments: [
7,
2,
],
result: None,
),
Emit((
start: 3,
end: 4,
)),
Emit((
start: 5,
end: 7,
)),
Call(
function: 6,
arguments: [
6,
8,
],
result: None,
),
Expand Down
Loading

0 comments on commit 6854b0a

Please sign in to comment.