Skip to content

Commit

Permalink
[msl] Fix host-shareable struct padding
Browse files Browse the repository at this point in the history
The struct usages are not kept up-to-date through transforms, so we
need to determine which structs are used in host-shareable address
spaces by inspecting the entry point parameters instead.

Bug: 42251016
Change-Id: Ia88720e52674e212b4f889de9c7ee45f026da3bc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/203635
Commit-Queue: James Price <[email protected]>
Reviewed-by: dan sinclair <[email protected]>
  • Loading branch information
jrprice authored and Dawn LUCI CQ committed Aug 22, 2024
1 parent b2d106e commit 520a86a
Show file tree
Hide file tree
Showing 385 changed files with 2,378 additions and 1,528 deletions.
38 changes: 37 additions & 1 deletion src/tint/lang/msl/writer/printer/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ class Printer : public tint::TextGenerator {
// Module-scope declarations should have all been moved into the entry points.
TINT_ASSERT(ir_.root_block->IsEmpty());

// Determine which structures will need to be emitted with host-shareable memory layouts.
FindHostShareableStructs();

// Emit functions.
for (auto* func : ir_.DependencyOrderedFunctions()) {
EmitFunction(func);
Expand Down Expand Up @@ -169,6 +172,7 @@ class Printer : public tint::TextGenerator {
/// Non-empty only if an invariant attribute has been generated.
std::string invariant_define_name_;

Hashset<const core::type::Struct*, 16> host_shareable_structs_;
std::unordered_set<const core::type::Struct*> emitted_structs_;

/// The current function being emitted
Expand Down Expand Up @@ -213,6 +217,38 @@ class Printer : public tint::TextGenerator {
return array_template_name_;
}

/// Find all structures that are used in host-shareable address spaces and mark them as such so
/// that we know to pad the properly when we emit them.
void FindHostShareableStructs() {
// We only look at function parameters of entry points, since this is how binding resources
// are handled in MSL.
for (auto func : ir_.functions) {
if (func->Stage() == core::ir::Function::PipelineStage::kUndefined) {
continue;
}
for (auto* param : func->Params()) {
auto* ptr = param->Type()->As<core::type::Pointer>();
if (ptr && core::IsHostShareable(ptr->AddressSpace())) {
// Look for structures at any nesting depth of this parameter's type.
Vector<const core::type::Type*, 8> type_queue;
type_queue.Push(ptr->StoreType());
while (!type_queue.IsEmpty()) {
auto* next = type_queue.Pop();
if (auto* str = next->As<core::type::Struct>()) {
// Record this structure as host-shareable.
host_shareable_structs_.Add(str);
for (auto* member : str->Members()) {
type_queue.Push(member->Type());
}
} else if (auto* arr = next->As<core::type::Array>()) {
type_queue.Push(arr->ElemType());
}
}
}
}
}
}

/// Check if a value is emitted as an actual pointer (instead of a reference).
/// @param value the value to check
/// @returns true if @p value will be emitted as an actual pointer
Expand Down Expand Up @@ -1363,7 +1399,7 @@ class Printer : public tint::TextGenerator {
TextBuffer str_buf;
Line(&str_buf) << "\n" << "struct " << StructName(str) << " {";

bool is_host_shareable = str->IsHostShareable();
bool is_host_shareable = host_shareable_structs_.Contains(str);

// Emits a `/* 0xnnnn */` byte offset comment for a struct member.
auto add_byte_offset_comment = [&](StringStream& out, uint32_t offset) {
Expand Down
98 changes: 49 additions & 49 deletions src/tint/lang/msl/writer/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,6 @@ TEST_F(MslWriterTest, EmitType_Struct_Layout_NonComposites) {
{mod.symbols.Register("z"), ty.f32()}};

auto* s = MkStruct(mod, ty, "S", data);
s->AddUsage(core::AddressSpace::kStorage);

// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, ARRAY_COUNT, NAME)
// for each field of the structure s.
Expand Down Expand Up @@ -532,20 +531,20 @@ TEST_F(MslWriterTest, EmitType_Struct_Layout_NonComposites) {
ALL_FIELDS()
#undef FIELD
expect << R"(};
void foo() {
thread S a = {};
}
)";

auto* func = b.Function("foo", ty.void_());
auto* var = b.Var("a", ty.ptr(core::AddressSpace::kStorage, s));
var->SetBindingPoint(0, 0);
mod.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute,
std::array<uint32_t, 3>{1u, 1u, 1u});
b.Append(func->Block(), [&] {
b.Var("a", ty.ptr(core::AddressSpace::kPrivate, s));
b.Load(var);
b.Return(func);
});

ASSERT_TRUE(Generate()) << err_ << output_.msl;
EXPECT_EQ(output_.msl, expect.str());
EXPECT_THAT(output_.msl, testing::HasSubstr(expect.str()));

// 1.4 Metal and C++14
// The Metal programming language is a C++14-based Specification with
Expand Down Expand Up @@ -587,7 +586,6 @@ TEST_F(MslWriterTest, EmitType_Struct_Layout_Structures) {
{mod.symbols.Register("c"), ty.f32()},
{mod.symbols.Register("d"), inner_y},
{mod.symbols.Register("e"), ty.f32()}});
const_cast<core::type::Struct*>(s)->AddUsage(core::AddressSpace::kStorage);

// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, ARRAY_COUNT, NAME)
// for each field of the structure s.
Expand All @@ -598,19 +596,22 @@ TEST_F(MslWriterTest, EmitType_Struct_Layout_Structures) {
FIELD(0x0600, float, 0, c) \
FIELD(0x0604, inner_y, 0, d) \
FIELD(0x0808, float, 0, e) \
FIELD(0x080c, int8_t, 500, tint_pad_1)
FIELD(0x080c, int8_t, 500, tint_pad_4)

// Check that the generated string is as expected.
StringStream expect;
expect << MetalHeader() << MetalArray() << R"(
struct inner_x {
int a;
float b;
/* 0x0000 */ int a;
/* 0x0004 */ tint_array<int8_t, 508> tint_pad_1;
/* 0x0200 */ float b;
/* 0x0204 */ tint_array<int8_t, 508> tint_pad_2;
};
struct inner_y {
int a;
float b;
/* 0x0000 */ int a;
/* 0x0004 */ tint_array<int8_t, 508> tint_pad_3;
/* 0x0200 */ float b;
};
)";
Expand All @@ -621,20 +622,20 @@ struct inner_y {
ALL_FIELDS()
#undef FIELD
expect << R"(};
void foo() {
thread S a = {};
}
)";

auto* func = b.Function("foo", ty.void_());
auto* var = b.Var("a", ty.ptr(core::AddressSpace::kStorage, s));
var->SetBindingPoint(0, 0);
mod.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute,
std::array<uint32_t, 3>{1u, 1u, 1u});
b.Append(func->Block(), [&] {
b.Var("a", ty.ptr(core::AddressSpace::kPrivate, s));
b.Load(var);
b.Return(func);
});

ASSERT_TRUE(Generate()) << err_ << output_.msl;
EXPECT_EQ(output_.msl, expect.str());
EXPECT_THAT(output_.msl, testing::HasSubstr(expect.str()));

// 1.4 Metal and C++14
// The Metal programming language is a C++14-based Specification with
Expand Down Expand Up @@ -694,7 +695,6 @@ TEST_F(MslWriterTest, EmitType_Struct_Layout_ArrayDefaultStride) {
{mod.symbols.Register("d"), array_y},
{mod.symbols.Register("e"), ty.f32()},
{mod.symbols.Register("f"), array_z}});
const_cast<core::type::Struct*>(s)->AddUsage(core::AddressSpace::kStorage);

// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, ARRAY_COUNT, NAME)
// for each field of the structure s.
Expand All @@ -706,15 +706,17 @@ TEST_F(MslWriterTest, EmitType_Struct_Layout_ArrayDefaultStride) {
FIELD(0x0200, inner, 4, d) \
FIELD(0x1200, float, 0, e) \
FIELD(0x1204, float, 1, f) \
FIELD(0x1208, int8_t, 504, tint_pad_1)
FIELD(0x1208, int8_t, 504, tint_pad_3)

// Check that the generated string is as expected.
StringStream expect;

expect << MetalHeader() << MetalArray() << R"(
struct inner {
int a;
float b;
/* 0x0000 */ int a;
/* 0x0004 */ tint_array<int8_t, 508> tint_pad_1;
/* 0x0200 */ float b;
/* 0x0204 */ tint_array<int8_t, 508> tint_pad_2;
};
)";
Expand All @@ -725,20 +727,20 @@ struct inner {
ALL_FIELDS()
#undef FIELD
expect << R"(};
void foo() {
thread S a = {};
}
)";

auto* func = b.Function("foo", ty.void_());
auto* var = b.Var("a", ty.ptr(core::AddressSpace::kStorage, s));
var->SetBindingPoint(0, 0);
mod.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute,
std::array<uint32_t, 3>{1u, 1u, 1u});
b.Append(func->Block(), [&] {
b.Var("a", ty.ptr(core::AddressSpace::kPrivate, s));
b.Load(var);
b.Return(func);
});

ASSERT_TRUE(Generate()) << err_ << output_.msl;
EXPECT_EQ(output_.msl, expect.str());
EXPECT_THAT(output_.msl, testing::HasSubstr(expect.str()));

// 1.4 Metal and C++14
// The Metal programming language is a C++14-based Specification with
Expand Down Expand Up @@ -791,7 +793,6 @@ TEST_F(MslWriterTest, EmitType_Struct_Layout_ArrayVec3DefaultStride) {
{mod.symbols.Register("b"), array},
{mod.symbols.Register("c"), ty.i32()},
});
const_cast<core::type::Struct*>(s)->AddUsage(core::AddressSpace::kStorage);

// ALL_FIELDS() calls the macro FIELD(ADDR, TYPE, ARRAY_COUNT, NAME)
// for each field of the structure s.
Expand All @@ -811,20 +812,20 @@ TEST_F(MslWriterTest, EmitType_Struct_Layout_ArrayVec3DefaultStride) {
ALL_FIELDS()
#undef FIELD
expect << R"(};
void foo() {
thread S a = {};
}
)";

auto* func = b.Function("foo", ty.void_());
auto* var = b.Var("a", ty.ptr(core::AddressSpace::kStorage, s));
var->SetBindingPoint(0, 0);
mod.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute,
std::array<uint32_t, 3>{1u, 1u, 1u});
b.Append(func->Block(), [&] {
b.Var("a", ty.ptr(core::AddressSpace::kPrivate, s));
b.Load(var);
b.Return(func);
});

ASSERT_TRUE(Generate()) << err_ << output_.msl;
EXPECT_EQ(output_.msl, expect.str());
EXPECT_THAT(output_.msl, testing::HasSubstr(expect.str()));
}

TEST_F(MslWriterTest, AttemptTintPadSymbolCollision) {
Expand Down Expand Up @@ -857,7 +858,6 @@ TEST_F(MslWriterTest, AttemptTintPadSymbolCollision) {
{mod.symbols.Register("tint_pad_21"), ty.f32()}};

auto* s = MkStruct(mod, ty, "S", data);
s->AddUsage(core::AddressSpace::kStorage);

auto expect = MetalHeader() + MetalArray() + R"(
struct S {
Expand Down Expand Up @@ -901,20 +901,20 @@ struct S {
/* 0x0300 */ float tint_pad_21;
/* 0x0304 */ tint_array<int8_t, 124> tint_pad_38;
};
void foo() {
thread S a = {};
}
)";

auto* func = b.Function("foo", ty.void_());
auto* var = b.Var("a", ty.ptr(core::AddressSpace::kStorage, s));
var->SetBindingPoint(0, 0);
mod.root_block->Append(var);
auto* func = b.Function("foo", ty.void_(), core::ir::Function::PipelineStage::kCompute,
std::array<uint32_t, 3>{1u, 1u, 1u});
b.Append(func->Block(), [&] {
b.Var("a", ty.ptr(core::AddressSpace::kPrivate, s));
b.Load(var);
b.Return(func);
});

ASSERT_TRUE(Generate()) << err_ << output_.msl;
EXPECT_EQ(output_.msl, expect);
EXPECT_THAT(output_.msl, testing::HasSubstr(expect));
}

TEST_F(MslWriterTest, EmitType_Sampler) {
Expand Down Expand Up @@ -1075,7 +1075,7 @@ using MslWriterStorageTexturesTest = MslWriterTestWithParam<MslStorageTextureDat
TEST_P(MslWriterStorageTexturesTest, Emit) {
auto params = GetParam();

auto* f32 = const_cast<core::type::F32*>(ty.f32());
auto* f32 = ty.f32();
auto s = ty.Get<core::type::StorageTexture>(params.dim, core::TexelFormat::kR32Float,
core::Access::kWrite, f32);
auto* func = b.Function("foo", ty.void_());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct tint_array {
};

struct S {
tint_array<int4, 4> arr;
/* 0x0000 */ tint_array<int4, 4> arr;
};

struct tint_module_vars_struct {
Expand Down
2 changes: 1 addition & 1 deletion test/tint/array/assign_to_private_var.wgsl.expected.ir.msl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct tint_array {
};

struct S {
tint_array<int4, 4> arr;
/* 0x0000 */ tint_array<int4, 4> arr;
};

struct tint_module_vars_struct {
Expand Down
4 changes: 2 additions & 2 deletions test/tint/array/assign_to_storage_var.wgsl.expected.ir.msl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ struct tint_array {
};

struct S {
tint_array<int4, 4> arr;
/* 0x0000 */ tint_array<int4, 4> arr;
};

struct S_nested {
tint_array<tint_array<tint_array<int, 2>, 3>, 4> arr;
/* 0x0000 */ tint_array<tint_array<tint_array<int, 2>, 3>, 4> arr;
};

struct tint_module_vars_struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct tint_array {
};

struct S {
tint_array<int4, 4> arr;
/* 0x0000 */ tint_array<int4, 4> arr;
};

struct tint_module_vars_struct {
Expand Down
14 changes: 8 additions & 6 deletions test/tint/array/strides.spvasm.expected.ir.msl
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
#include <metal_stdlib>
using namespace metal;

struct strided_arr {
float el;
};

template<typename T, size_t N>
struct tint_array {
const constant T& operator[](size_t i) const constant { return elements[i]; }
Expand All @@ -17,12 +13,18 @@ struct tint_array {
T elements[N];
};

struct strided_arr {
/* 0x0000 */ float el;
/* 0x0004 */ tint_array<int8_t, 4> tint_pad;
};

struct strided_arr_1 {
tint_array<tint_array<strided_arr, 2>, 3> el;
/* 0x0000 */ tint_array<tint_array<strided_arr, 2>, 3> el;
/* 0x0030 */ tint_array<int8_t, 80> tint_pad_1;
};

struct S {
tint_array<strided_arr_1, 4> a;
/* 0x0000 */ tint_array<strided_arr_1, 4> a;
};

struct tint_module_vars_struct {
Expand Down
Loading

0 comments on commit 520a86a

Please sign in to comment.