Skip to content

Commit

Permalink
Update SPIR-V Generator to return workgroup information.
Browse files Browse the repository at this point in the history
This CL adds an output to the SPIR-V generator to return the workgroup
size information after `SubstituteOverrides` has been run.

Bug: 380043961
Change-Id: Id7f2b203afca3ab1b29ad43ec70b83444ac0445c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/218699
Auto-Submit: dan sinclair <[email protected]>
Commit-Queue: dan sinclair <[email protected]>
Reviewed-by: James Price <[email protected]>
  • Loading branch information
dj2 authored and Dawn LUCI CQ committed Dec 10, 2024
1 parent c190829 commit 6f3bd8a
Show file tree
Hide file tree
Showing 15 changed files with 61 additions and 36 deletions.
2 changes: 0 additions & 2 deletions src/tint/lang/spirv/writer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,9 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
cc_library(
name = "writer",
srcs = [
"output.cc",
"writer.cc",
],
hdrs = [
"output.h",
"writer.h",
],
deps = [
Expand Down
2 changes: 0 additions & 2 deletions src/tint/lang/spirv/writer/BUILD.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ if(TINT_BUILD_SPV_WRITER)
# Condition: TINT_BUILD_SPV_WRITER
################################################################################
tint_add_target(tint_lang_spirv_writer lib
lang/spirv/writer/output.cc
lang/spirv/writer/output.h
lang/spirv/writer/writer.cc
lang/spirv/writer/writer.h
)
Expand Down
2 changes: 0 additions & 2 deletions src/tint/lang/spirv/writer/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ if (tint_build_unittests || tint_build_benchmarks) {
if (tint_build_spv_writer) {
libtint_source_set("writer") {
sources = [
"output.cc",
"output.h",
"writer.cc",
"writer.h",
]
Expand Down
2 changes: 2 additions & 0 deletions src/tint/lang/spirv/writer/common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ cc_library(
"module.cc",
"operand.cc",
"option_helper.cc",
"output.cc",
],
hdrs = [
"binary_writer.h",
Expand All @@ -54,6 +55,7 @@ cc_library(
"operand.h",
"option_helpers.h",
"options.h",
"output.h",
],
deps = [
"//src/tint/api/common",
Expand Down
2 changes: 2 additions & 0 deletions src/tint/lang/spirv/writer/common/BUILD.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ tint_add_target(tint_lang_spirv_writer_common lib
lang/spirv/writer/common/option_helper.cc
lang/spirv/writer/common/option_helpers.h
lang/spirv/writer/common/options.h
lang/spirv/writer/common/output.cc
lang/spirv/writer/common/output.h
)

tint_target_add_dependencies(tint_lang_spirv_writer_common lib
Expand Down
2 changes: 2 additions & 0 deletions src/tint/lang/spirv/writer/common/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ if (tint_build_spv_writer) {
"option_helper.cc",
"option_helpers.h",
"options.h",
"output.cc",
"output.h",
]
deps = [
"${dawn_root}/src/utils:utils",
Expand Down
10 changes: 5 additions & 5 deletions src/tint/lang/spirv/writer/common/helper_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@
#include "src/tint/lang/core/ir/builder.h"
#include "src/tint/lang/core/ir/disassembler.h"
#include "src/tint/lang/core/ir/validator.h"
#include "src/tint/lang/core/type/array.h"
#include "src/tint/lang/core/type/depth_texture.h"
#include "src/tint/lang/core/type/matrix.h"
#include "src/tint/lang/core/type/multisampled_texture.h"
#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/core/type/storage_texture.h"
#include "src/tint/lang/spirv/writer/common/spv_dump_test.h"
#include "src/tint/lang/spirv/writer/writer.h"
Expand Down Expand Up @@ -103,6 +98,9 @@ class SpirvWriterTestHelperBase : public BASE {
/// SPIR-V output.
std::string output_;

/// Workgroup info
std::optional<Output::WorkgroupInfo> workgroup_info;

/// @returns the error string from the validation
std::string Error() const { return err_; }

Expand All @@ -124,6 +122,8 @@ class SpirvWriterTestHelperBase : public BASE {
if (!Validate(result->spirv)) {
return false;
}
workgroup_info = result->workgroup_info;

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "src/tint/lang/spirv/writer/output.h"
#include "src/tint/lang/spirv/writer/common/output.h"

namespace tint::spirv::writer {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#ifndef SRC_TINT_LANG_SPIRV_WRITER_OUTPUT_H_
#define SRC_TINT_LANG_SPIRV_WRITER_OUTPUT_H_
#ifndef SRC_TINT_LANG_SPIRV_WRITER_COMMON_OUTPUT_H_
#define SRC_TINT_LANG_SPIRV_WRITER_COMMON_OUTPUT_H_

#include <cstdint>
#include <string>
#include <optional>
#include <vector>

namespace tint::spirv::writer {
Expand All @@ -49,10 +49,23 @@ struct Output {
/// @returns this
Output& operator=(const Output&);

/// Workgroup size information
struct WorkgroupInfo {
/// The x-component
uint32_t x;
/// The y-component
uint32_t y;
/// The z-component
uint32_t z;
};

/// The generated SPIR-V.
std::vector<uint32_t> spirv;

/// The workgroup size information, if the entry point was a compute shader
std::optional<WorkgroupInfo> workgroup_info = std::nullopt;
};

} // namespace tint::spirv::writer

#endif // SRC_TINT_LANG_SPIRV_WRITER_OUTPUT_H_
#endif // SRC_TINT_LANG_SPIRV_WRITER_COMMON_OUTPUT_H_
12 changes: 12 additions & 0 deletions src/tint/lang/spirv/writer/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ TEST_F(SpirvWriterTest, Function_Empty) {
OpReturn
OpFunctionEnd
)");
EXPECT_TRUE(workgroup_info.has_value());

// There is always an injected entry point if none exists in the source program.
EXPECT_EQ(workgroup_info->x, 1u);
EXPECT_EQ(workgroup_info->y, 1u);
EXPECT_EQ(workgroup_info->z, 1u);
}

// Test that we do not emit the same function type more than once.
Expand Down Expand Up @@ -113,6 +119,10 @@ TEST_F(SpirvWriterTest, Function_EntryPoint_Compute) {
OpReturn
OpFunctionEnd
)");
EXPECT_TRUE(workgroup_info.has_value());
EXPECT_EQ(workgroup_info->x, 32u);
EXPECT_EQ(workgroup_info->y, 4u);
EXPECT_EQ(workgroup_info->z, 1u);
}

TEST_F(SpirvWriterTest, Function_EntryPoint_Fragment) {
Expand All @@ -139,6 +149,7 @@ TEST_F(SpirvWriterTest, Function_EntryPoint_Fragment) {
OpReturn
OpFunctionEnd
)");
EXPECT_FALSE(workgroup_info.has_value());
}

TEST_F(SpirvWriterTest, Function_EntryPoint_Vertex) {
Expand Down Expand Up @@ -190,6 +201,7 @@ TEST_F(SpirvWriterTest, Function_EntryPoint_Vertex) {
OpReturn
OpFunctionEnd
)");
EXPECT_FALSE(workgroup_info.has_value());
}

TEST_F(SpirvWriterTest, Function_EntryPoint_Multiple) {
Expand Down
16 changes: 13 additions & 3 deletions src/tint/lang/spirv/writer/printer/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class Printer {
}

/// @returns the generated SPIR-V code on success, or failure
Result<std::vector<uint32_t>> Code() {
Result<Output> Code() {
if (auto res = Generate(); res != Success) {
return res.Failure();
}
Expand All @@ -203,7 +203,11 @@ class Printer {
BinaryWriter writer;
writer.WriteHeader(module_.IdBound(), kWriterVersion);
writer.WriteModule(module_);
return std::move(writer.Result());

Output output;
output.spirv = std::move(writer.Result());
output.workgroup_info = workgroup_info;
return output;
}

private:
Expand Down Expand Up @@ -278,6 +282,9 @@ class Printer {

bool zero_init_workgroup_memory_ = false;

/// Workgroup information emitted
std::optional<Output::WorkgroupInfo> workgroup_info = std::nullopt;

/// Builds the SPIR-V from the IR
Result<SuccessType> Generate() {
auto valid = core::ir::ValidateAndDumpIfNeeded(ir_, "spirv.Printer");
Expand Down Expand Up @@ -768,6 +775,9 @@ class Printer {
auto const_wg_size = func->WorkgroupSizeAsConst();
TINT_ASSERT(const_wg_size);

// Store the workgroup information away to return from the generator.
workgroup_info = {(*const_wg_size)[0], (*const_wg_size)[1], (*const_wg_size)[2]};

module_.PushExecutionMode(
spv::Op::OpExecutionMode,
{id, U32Operand(SpvExecutionModeLocalSize), const_wg_size->at(0),
Expand Down Expand Up @@ -2539,7 +2549,7 @@ class Printer {

} // namespace

tint::Result<std::vector<uint32_t>> Print(core::ir::Module& module, const Options& options) {
tint::Result<Output> Print(core::ir::Module& module, const Options& options) {
return Printer{module, options}.Code();
}

Expand Down
7 changes: 2 additions & 5 deletions src/tint/lang/spirv/writer/printer/printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@
#ifndef SRC_TINT_LANG_SPIRV_WRITER_PRINTER_PRINTER_H_
#define SRC_TINT_LANG_SPIRV_WRITER_PRINTER_PRINTER_H_

#include <cstdint>
#include <vector>

#include "src/tint/lang/spirv/writer/common/module.h"
#include "src/tint/lang/spirv/writer/common/options.h"
#include "src/tint/lang/spirv/writer/common/output.h"
#include "src/tint/utils/result/result.h"

// Forward declarations
Expand All @@ -45,7 +42,7 @@ namespace tint::spirv::writer {
/// @returns the generated SPIR-V instructions on success, or failure
/// @param module the Tint IR module to generate
/// @param options the printer options
tint::Result<std::vector<uint32_t>> Print(core::ir::Module& module, const Options& options);
tint::Result<Output> Print(core::ir::Module& module, const Options& options);

} // namespace tint::spirv::writer

Expand Down
3 changes: 3 additions & 0 deletions src/tint/lang/spirv/writer/texture_builtin_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#include "src/tint/lang/core/builtin_fn.h"
#include "src/tint/lang/core/fluent_types.h"
#include "src/tint/lang/core/type/depth_multisampled_texture.h"
#include "src/tint/lang/core/type/depth_texture.h"
#include "src/tint/lang/core/type/multisampled_texture.h"
#include "src/tint/lang/core/type/sampled_texture.h"
#include "src/tint/lang/spirv/writer/common/helper_test.h"

using namespace tint::core::number_suffixes; // NOLINT
Expand Down
12 changes: 1 addition & 11 deletions src/tint/lang/spirv/writer/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

#include "src/tint/lang/spirv/writer/writer.h"

#include <memory>
#include <utility>

#include "src/tint/lang/spirv/writer/common/option_helpers.h"
Expand All @@ -47,21 +46,12 @@ Result<Output> Generate(core::ir::Module& ir, const Options& options) {
}
}

Output output;

// Raise from core-dialect to SPIR-V-dialect.
if (auto res = Raise(ir, options); res != Success) {
return std::move(res.Failure());
}

// Generate the SPIR-V code.
auto spirv = Print(ir, options);
if (spirv != Success) {
return std::move(spirv.Failure());
}
output.spirv = std::move(spirv.Get());

return output;
return Print(ir, options);
}

} // namespace tint::spirv::writer
2 changes: 1 addition & 1 deletion src/tint/lang/spirv/writer/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

#include "src/tint/lang/core/ir/module.h"
#include "src/tint/lang/spirv/writer/common/options.h"
#include "src/tint/lang/spirv/writer/output.h"
#include "src/tint/lang/spirv/writer/common/output.h"
#include "src/tint/utils/result/result.h"

namespace tint::spirv::writer {
Expand Down

0 comments on commit 6f3bd8a

Please sign in to comment.