Skip to content

Commit

Permalink
[spirv] Move symbol renaming into the backend
Browse files Browse the repository at this point in the history
Instead of running the AST renamer transform in Dawn's Vulkan backend,
we can just ignore all the names when emitting SPIR-V.

Add an extra option so that Dawn can specify the entry point name to
use for the OpEntryPoint instruction.

Bug: 380043958
Change-Id: I62546be79bf2adaa34cee2550758f5941475f5bd
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/218616
Commit-Queue: James Price <[email protected]>
Reviewed-by: dan sinclair <[email protected]>
  • Loading branch information
jrprice authored and Dawn LUCI CQ committed Dec 9, 2024
1 parent 644bf00 commit b8d9de5
Show file tree
Hide file tree
Showing 13 changed files with 241 additions and 171 deletions.
23 changes: 6 additions & 17 deletions src/dawn/native/vulkan/ShaderModuleVk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ ShaderModule::~ShaderModule() = default;
X(std::optional<tint::ast::transform::SubstituteOverride::Config>, substituteOverrideConfig) \
X(LimitsForCompilationRequest, limits) \
X(std::string_view, entryPointName) \
X(bool, disableSymbolRenaming) \
X(tint::spirv::writer::Options, tintOptions) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, platform)

Expand Down Expand Up @@ -338,9 +337,12 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
auto tintProgram = GetTintProgram();
req.inputProgram = &(tintProgram->program);
req.entryPointName = programmableStage.entryPoint;
req.disableSymbolRenaming = GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming);
req.platform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
req.substituteOverrideConfig = std::move(substituteOverrideConfig);

req.tintOptions.remapped_entry_point_name = kRemappedEntryPointName;
req.tintOptions.strip_all_names = !GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming);

req.tintOptions.statically_paired_texture_binding_points =
std::move(statically_paired_texture_binding_points);
req.tintOptions.clamp_frag_depth = clampFragDepth;
Expand Down Expand Up @@ -387,23 +389,10 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
tint::ast::transform::DataMap transformInputs;

// Many Vulkan drivers can't handle multi-entrypoint shader modules.
// Run before the renamer so that the entry point name matches `entryPointName` still.
transformManager.append(std::make_unique<tint::ast::transform::SingleEntryPoint>());
transformInputs.Add<tint::ast::transform::SingleEntryPoint::Config>(
std::string(r.entryPointName));

// Rename symbols unless symbol renaming is disabled.
std::string remappedEntryPointName = std::string(r.entryPointName);
if (!r.disableSymbolRenaming) {
constexpr char kRemappedEntryPointName[] = "dawn_entry_point";
tint::ast::transform::Renamer::Remappings requestedNames = {
{remappedEntryPointName, kRemappedEntryPointName}};
transformManager.Add<tint::ast::transform::Renamer>();
transformInputs.Add<tint::ast::transform::Renamer::Config>(
tint::ast::transform::Renamer::Target::kAll, std::move(requestedNames));
remappedEntryPointName = kRemappedEntryPointName;
}

if (r.substituteOverrideConfig) {
// This needs to run after SingleEntryPoint transform which removes unused overrides
// for current entry point.
Expand All @@ -425,7 +414,7 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
if (r.stage == SingleShaderStage::Compute) {
Extent3D _;
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
program, remappedEntryPointName.c_str(), r.limits));
program, r.entryPointName.data(), r.limits));
}

TRACE_EVENT0(r.platform.UnsafeGetValue(), General, "tint::spirv::writer::Generate()");
Expand All @@ -443,7 +432,7 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(

CompiledSpirv result;
result.spirv = std::move(tintResult.Get().spirv);
result.remappedEntryPoint = remappedEntryPointName;
result.remappedEntryPoint = kRemappedEntryPointName;
return result;
},
"Vulkan.CompileShaderToSPIRV");
Expand Down
3 changes: 3 additions & 0 deletions src/dawn/native/vulkan/ShaderModuleVk.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ struct ProgrammableStage;

namespace vulkan {

// The entry point name to use when generating SPIR-V.
constexpr char kRemappedEntryPointName[] = "dawn_entry_point";

struct TransformedShaderModuleCacheKey {
uintptr_t layoutPtr;
std::string entryPoint;
Expand Down
6 changes: 6 additions & 0 deletions src/tint/cmd/tint/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,8 @@ When specified, automatically enables HLSL validation)",
}
case Format::kSpirv:
case Format::kSpvAsm:
// Renaming is handled in the backend.
break;
case Format::kWgsl:
case Format::kIr: {
if (options.rename_all) {
Expand Down Expand Up @@ -762,6 +764,10 @@ bool GenerateSpirv([[maybe_unused]] Options& options,
}

tint::spirv::writer::Options gen_options;
if (options.rename_all) {
gen_options.remapped_entry_point_name = "tint_entry_point";
gen_options.strip_all_names = true;
}
gen_options.disable_robustness = !options.enable_robustness;
gen_options.disable_workgroup_init = options.disable_workgroup_init;
gen_options.use_storage_input_output_16 = options.use_storage_input_output_16;
Expand Down
10 changes: 10 additions & 0 deletions src/tint/lang/spirv/writer/common/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#ifndef SRC_TINT_LANG_SPIRV_WRITER_COMMON_OPTIONS_H_
#define SRC_TINT_LANG_SPIRV_WRITER_COMMON_OPTIONS_H_

#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>

Expand Down Expand Up @@ -129,6 +131,9 @@ struct Bindings {

/// Configuration options used for generating SPIR-V.
struct Options {
/// An optional remapped name to use when emitting the entry point.
std::optional<std::string> remapped_entry_point_name;

/// The bindings
Bindings bindings;

Expand All @@ -138,6 +143,9 @@ struct Options {
// the samplers with which they are paired.
std::unordered_set<BindingPoint> statically_paired_texture_binding_points = {};

/// Set to `true` to strip all user-declared identifiers from the module.
bool strip_all_names = false;

/// Set to `true` to disable software robustness that prevents out-of-bounds accesses.
bool disable_robustness = false;

Expand Down Expand Up @@ -181,8 +189,10 @@ struct Options {

/// Reflect the fields of this class so that it can be used by tint::ForeachField()
TINT_REFLECT(Options,
remapped_entry_point_name,
bindings,
statically_paired_texture_binding_points,
strip_all_names,
disable_robustness,
disable_image_robustness,
disable_runtime_sized_array_index_clamping,
Expand Down
58 changes: 37 additions & 21 deletions src/tint/lang/spirv/writer/printer/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include "src/tint/lang/spirv/writer/printer/printer.h"

#include <string>
#include <utility>

#include "spirv/unified1/GLSL.std.450.h"
Expand Down Expand Up @@ -379,9 +380,7 @@ class Printer {
auto id = Constant(constant->Value());

// Set the name for the SPIR-V result ID if provided in the module.
if (auto name = ir_.NameOf(constant)) {
module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
}
PushName(id, constant);

return id;
}
Expand Down Expand Up @@ -594,10 +593,7 @@ class Printer {
Operand(effective_row_count * matrix_type->Type()->Size())});
}

if (member->Name().IsValid()) {
module_.PushDebug(spv::Op::OpMemberName,
{operands[0], member->Index(), Operand(member->Name().Name())});
}
PushMemberName(id, member->Index(), member->Name());
}
module_.PushType(spv::Op::OpTypeStruct, std::move(operands));

Expand All @@ -606,9 +602,7 @@ class Printer {
module_.PushAnnot(spv::Op::OpDecorate, {id, U32Operand(SpvDecorationBlock)});
}

if (str->Name().IsValid()) {
module_.PushDebug(spv::Op::OpName, {operands[0], Operand(str->Name().Name())});
}
PushName(id, str->Name());
}

/// Emit a texture type.
Expand Down Expand Up @@ -710,7 +704,7 @@ class Printer {
auto id = Value(func);

// Emit the function name.
module_.PushDebug(spv::Op::OpName, {id, Operand(ir_.NameOf(func).Name())});
PushName(id, func);

// Emit OpEntryPoint and OpExecutionMode declarations if needed.
if (func->Stage() != core::ir::Function::PipelineStage::kUndefined) {
Expand All @@ -730,9 +724,7 @@ class Printer {
auto param_id = Value(param);
params.push_back(Instruction(spv::Op::OpFunctionParameter, {param_type_id, param_id}));
function_type.param_type_ids.Push(param_type_id);
if (auto name = ir_.NameOf(param)) {
module_.PushDebug(spv::Op::OpName, {param_id, Operand(name.Name())});
}
PushName(param_id, param);
}

// Get the ID for the function type (creating it if needed).
Expand Down Expand Up @@ -796,7 +788,16 @@ class Printer {
TINT_ICE() << "undefined pipeline stage for entry point";
}

OperandList operands = {U32Operand(stage), id, ir_.NameOf(func).Name()};
// Use the remapped entry point name if requested, otherwise use the original name.
std::string name;
if (options_.remapped_entry_point_name) {
name = *options_.remapped_entry_point_name;
} else {
name = ir_.NameOf(func).Name();
}
TINT_ASSERT(!name.empty());

OperandList operands = {U32Operand(stage), id, name};

// Add the list of all referenced shader IO variables.
for (auto* global : *ir_.root_block) {
Expand Down Expand Up @@ -928,9 +929,7 @@ class Printer {

// Set the name for the SPIR-V result ID if provided in the module.
if (inst->Result(0) && !inst->Is<core::ir::Var>()) {
if (auto name = ir_.NameOf(inst)) {
module_.PushDebug(spv::Op::OpName, {Value(inst), Operand(name.Name())});
}
PushName(Value(inst), inst);
}
}
}
Expand Down Expand Up @@ -2388,9 +2387,7 @@ class Printer {
}

// Set the name if present.
if (auto name = ir_.NameOf(var)) {
module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
}
PushName(id, var);
}

/// Emit a let instruction.
Expand Down Expand Up @@ -2519,6 +2516,25 @@ class Printer {
}
return SpvImageFormatUnknown;
}

/// Set the debug name of an instruction.
void PushName(uint32_t id, core::ir::Instruction* inst) { PushName(id, ir_.NameOf(inst)); }
/// Set the debug name of a value.
void PushName(uint32_t id, core::ir::Value* value) { PushName(id, ir_.NameOf(value)); }
/// Set the debug name for a SPIR-V ID.
void PushName(uint32_t id, const Symbol& name) {
// Only set the name if it is valid and if we are not stripping user identifiers.
if (name && !options_.strip_all_names) {
module_.PushDebug(spv::Op::OpName, {id, Operand(name.Name())});
}
}
/// Set the debug member name for a SPIR-V ID.
void PushMemberName(uint32_t id, uint32_t index, const Symbol& name) {
// Only set the name if it is valid and if we are not stripping user identifiers.
if (name && !options_.strip_all_names) {
module_.PushDebug(spv::Op::OpMemberName, {id, index, Operand(name.Name())});
}
}
};

} // namespace
Expand Down
1 change: 1 addition & 0 deletions src/tint/lang/spirv/writer/raise/raise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <utility>

#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/core/ir/transform/add_empty_entry_point.h"
#include "src/tint/lang/core/ir/transform/bgra8unorm_polyfill.h"
#include "src/tint/lang/core/ir/transform/binary_polyfill.h"
Expand Down
89 changes: 89 additions & 0 deletions src/tint/lang/spirv/writer/writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
namespace tint::spirv::writer {
namespace {

using namespace tint::core::fluent_types; // NOLINT
using namespace tint::core::number_suffixes; // NOLINT

TEST_F(SpirvWriterTest, ModuleHeader) {
Expand Down Expand Up @@ -118,5 +119,93 @@ TEST_F(SpirvWriterTest, TooManyFunctionParameters) {
"Function 'foo' has more than 255 parameters after running Tint transforms"));
}

TEST_F(SpirvWriterTest, EntryPointName_Remapped) {
auto* func = b.ComputeFunction("main");
b.Append(func->Block(), [&] { //
b.Return(func);
});

Options options;
options.remapped_entry_point_name = "my_entry_point";
ASSERT_TRUE(Generate(options)) << Error() << output_;
EXPECT_INST("OpEntryPoint GLCompute %main \"my_entry_point\"");
}

TEST_F(SpirvWriterTest, EntryPointName_NotRemapped) {
auto* func = b.ComputeFunction("main");
b.Append(func->Block(), [&] { //
b.Return(func);
});

Options options;
options.remapped_entry_point_name = {};
ASSERT_TRUE(Generate(options)) << Error() << output_;
EXPECT_INST("OpEntryPoint GLCompute %main \"main\"");
}

TEST_F(SpirvWriterTest, StripAllNames) {
auto* str =
ty.Struct(mod.symbols.New("MyStruct"), {
{mod.symbols.Register("a"), ty.i32()},
{mod.symbols.Register("b"), ty.vec4<i32>()},
});
auto* func = b.ComputeFunction("main");
auto* idx = b.FunctionParam("idx", ty.u32());
idx->SetBuiltin(core::BuiltinValue::kLocalInvocationIndex);
func->AppendParam(idx);
b.Append(func->Block(), [&] { //
auto* var = b.Var("str", ty.ptr<function>(str));
auto* val = b.Load(var);
mod.SetName(val, "val");
auto* a = b.Access<i32>(val, 0_u);
mod.SetName(a, "a");
b.Return(func);
});

Options options;
options.strip_all_names = true;
options.remapped_entry_point_name = "tint_entry_point";
ASSERT_TRUE(Generate(options)) << Error() << output_;
EXPECT_INST(R"(
OpEntryPoint GLCompute %16 "tint_entry_point" %gl_LocalInvocationIndex
OpExecutionMode %16 LocalSize 1 1 1
; Annotations
OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
OpMemberDecorate %_struct_11 0 Offset 0
OpMemberDecorate %_struct_11 1 Offset 16
; Types, variables and constants
%uint = OpTypeInt 32 0
%_ptr_Input_uint = OpTypePointer Input %uint
%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input ; BuiltIn LocalInvocationIndex
%void = OpTypeVoid
%7 = OpTypeFunction %void %uint
%int = OpTypeInt 32 1
%v4int = OpTypeVector %int 4
%_struct_11 = OpTypeStruct %int %v4int
%_ptr_Function__struct_11 = OpTypePointer Function %_struct_11
%14 = OpConstantNull %_struct_11
%17 = OpTypeFunction %void
; Function 4
%4 = OpFunction %void None %7
%6 = OpFunctionParameter %uint
%8 = OpLabel
%9 = OpVariable %_ptr_Function__struct_11 Function %14
%15 = OpLoad %_struct_11 %9 None
OpReturn
OpFunctionEnd
; Function 16
%16 = OpFunction %void None %17
%18 = OpLabel
%19 = OpLoad %uint %gl_LocalInvocationIndex None
%20 = OpFunctionCall %void %4 %19
OpReturn
OpFunctionEnd
)");
}

} // namespace
} // namespace tint::spirv::writer
13 changes: 5 additions & 8 deletions test/tint/bug/chromium/1236161.wgsl.expected.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,17 @@
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %tint_symbol "tint_symbol"
OpName %tint_symbol_1 "tint_symbol_1"
OpName %unused_entry_point "unused_entry_point"
OpEntryPoint GLCompute %7 "tint_entry_point"
OpExecutionMode %7 LocalSize 1 1 1
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%tint_symbol_1 = OpConstant %float 1
%tint_symbol = OpFunction %void None %3
%float_1 = OpConstant %float 1
%1 = OpFunction %void None %3
%4 = OpLabel
OpReturn
OpFunctionEnd
%unused_entry_point = OpFunction %void None %3
%7 = OpFunction %void None %3
%8 = OpLabel
OpReturn
OpFunctionEnd
Loading

0 comments on commit b8d9de5

Please sign in to comment.