Skip to content

Commit

Permalink
Add D3D12 Backend for MultiDrawIndirect Feature
Browse files Browse the repository at this point in the history
This CL adds support for the D3D12 backend, which requires emulation of
base vertex and base instance parameters in the indirect arguments.
The indirect draw validation compute pass duplicates these parameters if they are used in the shaders.
The MultiDrawIndirect feature is enabled on all D3D12 devices.
New tests are added for MultiDrawIndexedIndirect to cover baseVertex and firstInstance.

Change-Id: I75bc48243e4801f49e6e50091cf1560593a3d14c
Bug: 356461286
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/203254
Reviewed-by: Loko Kung <[email protected]>
Commit-Queue: Srijan Dhungana <[email protected]>
Reviewed-by: Austin Eng <[email protected]>
  • Loading branch information
Sirtsu55 authored and Dawn LUCI CQ committed Aug 21, 2024
1 parent 17f3173 commit 3870081
Show file tree
Hide file tree
Showing 8 changed files with 343 additions and 38 deletions.
6 changes: 5 additions & 1 deletion src/dawn/native/IndirectDrawMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,23 +242,27 @@ void IndirectDrawMetadata::ClearIndexedIndirectBufferValidationInfo() {
mIndexedIndirectBufferValidationInfo.clear();
}

void IndirectDrawMetadata::AddMultiDrawIndirect(MultiDrawIndirectCmd* cmd) {
void IndirectDrawMetadata::AddMultiDrawIndirect(bool duplicateBaseVertexInstance,
MultiDrawIndirectCmd* cmd) {
IndirectMultiDraw multiDraw;
multiDraw.type = DrawType::NonIndexed;
multiDraw.cmd = cmd;
multiDraw.duplicateBaseVertexInstance = duplicateBaseVertexInstance;
mMultiDraws.push_back(multiDraw);
}

// Appends an indexed multi draw entry to mMultiDraws.
//
// The entry records the index format, the index buffer size, the duplicateBaseVertexInstance
// flag and the command pointer. Note that indexBuffer and indexBufferOffset are accepted but
// are not recorded by this function.
void IndirectDrawMetadata::AddMultiDrawIndexedIndirect(BufferBase* indexBuffer,
                                                       wgpu::IndexFormat indexFormat,
                                                       uint64_t indexBufferSize,
                                                       uint64_t indexBufferOffset,
                                                       bool duplicateBaseVertexInstance,
                                                       MultiDrawIndexedIndirectCmd* cmd) {
    IndirectMultiDraw entry;
    entry.cmd = cmd;
    entry.type = DrawType::Indexed;
    entry.indexFormat = indexFormat;
    entry.indexBufferSize = indexBufferSize;
    entry.duplicateBaseVertexInstance = duplicateBaseVertexInstance;
    mMultiDraws.push_back(entry);
}
Expand Down
8 changes: 5 additions & 3 deletions src/dawn/native/IndirectDrawMetadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ class IndirectDrawMetadata : public NonCopyable {
struct IndirectMultiDraw {
DrawType type;

uint64_t indexBufferSize;
wgpu::IndexFormat indexFormat;
uint64_t indexBufferSize = 0;
wgpu::IndexFormat indexFormat = wgpu::IndexFormat::Undefined;
bool duplicateBaseVertexInstance;

// When validation is enabled, the original indirect buffer is validated and copied to a new
// indirect buffer containing only valid commands. The pointer to the command allocated in
Expand Down Expand Up @@ -166,12 +167,13 @@ class IndirectDrawMetadata : public NonCopyable {
bool duplicateBaseVertexInstance,
DrawIndirectCmd* cmd);

void AddMultiDrawIndirect(MultiDrawIndirectCmd* cmd);
void AddMultiDrawIndirect(bool duplicateBaseVertexInstance, MultiDrawIndirectCmd* cmd);

void AddMultiDrawIndexedIndirect(BufferBase* indexBuffer,
wgpu::IndexFormat indexFormat,
uint64_t indexBufferSize,
uint64_t indexBufferOffset,
bool duplicateBaseVertexInstance,
MultiDrawIndexedIndirectCmd* cmd);

void ClearIndexedIndirectBufferValidationInfo();
Expand Down
115 changes: 87 additions & 28 deletions src/dawn/native/IndirectDrawValidationEncoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,15 @@ static const char sRenderValidationShaderSource[] = R"(
let inIndex = drawIndex * numInputParams;
let inputOffset = drawConstants.indirectOffsetInElements;
if (bool(drawConstants.flags & kDuplicateBaseVertexInstance)) {
// first/baseVertex and firstInstance are always last two parameters
let dupIndex = inputOffset + inIndex + numInputParams - 2u;
outputParams.data[outIndex] = inputParams.data[dupIndex];
outputParams.data[outIndex + 1u] = inputParams.data[dupIndex + 1u];
outIndex = outIndex + 2u;
}
for(var i = 0u; i < numInputParams; i = i + 1u) {
outputParams.data[outIndex + i] = inputParams.data[inputOffset + inIndex + i];
}
Expand Down Expand Up @@ -266,9 +275,7 @@ static const char sRenderValidationShaderSource[] = R"(
@compute @workgroup_size(kWorkgroupSize, 1, 1)
fn validate_multi_draw(@builtin(global_invocation_id) id : vec3u) {
var drawCount = drawConstants.maxDrawCount;
var drawCountOffset = drawConstants.drawCountOffsetInElements;
if(bool(drawConstants.flags & kIndirectDrawCountBuffer)) {
Expand All @@ -280,6 +287,16 @@ static const char sRenderValidationShaderSource[] = R"(
return;
}
if(!bool(drawConstants.flags & kValidationEnabled)) {
set_pass_multi(id.x);
return;
}
if (!bool(drawConstants.flags & kIndexedDraw)) {
set_pass_multi(id.x);
return;
}
let numIndexBufferElementsHigh = drawConstants.numIndexBufferElementsHigh;
if (numIndexBufferElementsHigh >= 2u) {
Expand Down Expand Up @@ -313,6 +330,17 @@ static const char sRenderValidationShaderSource[] = R"(
)";

// Returns the size in bytes of one draw's arguments as written to the output indirect buffer.
//
// The base size depends on the draw type (kDrawIndexedIndirectSize for indexed draws,
// kDrawIndirectSize otherwise). When duplicateBaseVertexInstance is set, two extra uint32_t
// values are accounted for on top of the base size.
static constexpr uint32_t GetOutputIndirectDrawSize(IndirectDrawMetadata::DrawType drawType,
                                                    bool duplicateBaseVertexInstance) {
    // Room for the two duplicated uint32_t parameters.
    constexpr uint32_t kDuplicatedParamsSize = 2 * sizeof(uint32_t);

    const uint32_t baseSize = drawType == IndirectDrawMetadata::DrawType::Indexed
                                  ? kDrawIndexedIndirectSize
                                  : kDrawIndirectSize;
    return duplicateBaseVertexInstance ? baseSize + kDuplicatedParamsSize : baseSize;
}

ResultOrError<dawn::Ref<ComputePipelineBase>> CreateRenderValidationPipelines(
DeviceBase* device,
const char* entryPoint,
Expand Down Expand Up @@ -473,10 +501,8 @@ MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
config.drawType == IndirectDrawMetadata::DrawType::Indexed ? kDrawIndexedIndirectSize
: kDrawIndirectSize;

uint64_t outputIndirectSize = indirectDrawCommandSize;
if (config.duplicateBaseVertexInstance) {
outputIndirectSize += 2 * sizeof(uint32_t);
}
uint64_t outputIndirectSize =
GetOutputIndirectDrawSize(config.drawType, config.duplicateBaseVertexInstance);

for (const IndirectDrawMetadata::IndirectValidationBatch& batch :
validationInfo.GetBatches()) {
Expand Down Expand Up @@ -552,11 +578,22 @@ MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
uint64_t outputParamsSizeForMultiDraw = 0;
// Calculate size of output params for multi draws
for (auto& draw : multiDraws) {
// Don't need to validate non-indexed draws.
if (draw.type == IndirectDrawMetadata::DrawType::NonIndexed) {
// Multi draw metadatas are added even if validation is disabled, because the Metal backend
// needs to convert all multi draws into an ICB. If validation is disabled, and the draw
// doesn't need duplication of base vertex and instance, we can skip the compute pass.
// In general, non-indexed multi draws don't need validation.
if ((draw.type == IndirectDrawMetadata::DrawType::NonIndexed ||
!device->IsValidationEnabled()) &&
!draw.duplicateBaseVertexInstance) {
continue;
}
outputParamsSizeForMultiDraw += draw.cmd->maxDrawCount * kDrawIndexedIndirectSize;

outputParamsSizeForMultiDraw +=
draw.cmd->maxDrawCount *
GetOutputIndirectDrawSize(draw.type, draw.duplicateBaseVertexInstance);

outputParamsSizeForMultiDraw =
Align(outputParamsSizeForMultiDraw, minStorageBufferOffsetAlignment);

if (outputParamsSizeForMultiDraw > maxStorageBufferBindingSize) {
return DAWN_INTERNAL_ERROR("Too many multiDrawIndexedIndirect calls to validate");
Expand All @@ -580,7 +617,8 @@ MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
for (const Pass& pass : passes) {
requiredBatchDataBufferSize = std::max(requiredBatchDataBufferSize, pass.batchDataSize);
}
// Needs to at least be able to store a MultiDrawConstants struct for the multi draw validation.
// Needs to at least be able to store a MultiDrawConstants struct for the multi draw
// validation.
requiredBatchDataBufferSize =
std::max(requiredBatchDataBufferSize, static_cast<uint64_t>(sizeof(MultiDrawConstants)));

Expand Down Expand Up @@ -656,10 +694,10 @@ MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
bindGroupDescriptor.entryCount = 3;
bindGroupDescriptor.entries = bindings;

// Finally, we can now encode our validation and duplication passes. Each pass first does
// two WriteBuffer to get batch and pass data over to the GPU, followed by a single compute
// pass. The compute pass encodes a separate SetBindGroup and Dispatch command for each
// batch.
// Finally, we can now encode our validation and duplication passes. Each pass first
// does a WriteBuffer to get batch and pass data over to the GPU, followed by a single
// compute pass. The compute pass encodes a separate SetBindGroup and Dispatch command
// for each batch.
for (const Pass& pass : passes) {
commandEncoder->APIWriteBuffer(batchDataBuffer.GetBuffer(), 0,
static_cast<const uint8_t*>(pass.batchData.get()),
Expand Down Expand Up @@ -724,12 +762,19 @@ MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
uint64_t outputOffset = multiDrawOutputParamsOffset;

for (auto& draw : multiDraws) {
if (draw.type == IndirectDrawMetadata::DrawType::NonIndexed) {
// If the draw meets these conditions, there is no need to run the compute pass,
// and there is no space allocated for the output params
if ((draw.type == IndirectDrawMetadata::DrawType::NonIndexed ||
!device->IsValidationEnabled()) &&
!draw.duplicateBaseVertexInstance) {
continue;
}

const size_t formatSize = IndexFormatSize(draw.indexFormat);
uint64_t numIndexBufferElements = draw.indexBufferSize / formatSize;
uint64_t numIndexBufferElements = 0;
if (draw.type == IndirectDrawMetadata::DrawType::Indexed) {
const size_t formatSize = IndexFormatSize(draw.indexFormat);
numIndexBufferElements = draw.indexBufferSize / formatSize;
}

// Same struct for both indexed and non-indexed draws.
MultiDrawIndirectCmd* cmd = draw.cmd;
Expand All @@ -748,10 +793,20 @@ MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
static_cast<uint32_t>(numIndexBufferElements & 0xFFFFFFFF);
drawConstants.numIndexBufferElementsHigh =
static_cast<uint32_t>((numIndexBufferElements >> 32) & 0xFFFFFFFF);
drawConstants.flags = kIndexedDraw;

drawConstants.flags = 0;
if (device->IsValidationEnabled()) {
drawConstants.flags |= kValidationEnabled;
}
if (draw.type == IndirectDrawMetadata::DrawType::Indexed) {
drawConstants.flags |= kIndexedDraw;
}
if (cmd->drawCountBuffer != nullptr) {
drawConstants.flags |= kIndirectDrawCountBuffer;
}
if (draw.duplicateBaseVertexInstance) {
drawConstants.flags |= kDuplicateBaseVertexInstance;
}

inputIndirectBinding.buffer = cmd->indirectBuffer.Get();
// We can't use the offset directly because the indirect offset is guaranteed to
Expand All @@ -763,19 +818,23 @@ MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,

outputParamsBinding.buffer = outputParamsBuffer.GetBuffer();
outputParamsBinding.offset = outputOffset;
outputParamsBinding.size =
draw.cmd->maxDrawCount *
GetOutputIndirectDrawSize(draw.type, draw.duplicateBaseVertexInstance);

if (cmd->drawCountBuffer != nullptr) {
// If the drawCountBuffer is set, we need to bind it to the bind group.
// The drawCountBuffer is used to read the drawCount for the multi draw call.
// If the drawCount exceeds the maxDrawCount, it will be clamped to maxDrawCount.
// If the drawCount exceeds the maxDrawCount, it will be clamped to
// maxDrawCount.
drawCountBinding.buffer = cmd->drawCountBuffer.Get();
drawCountBinding.offset =
AlignDown(cmd->drawCountOffset, minStorageBufferOffsetAlignment);
} else {
// This is an unused binding.
// Bind group entry for the drawCountBuffer is not needed however we need to bind
// something else than nullptr to the bind group entry to avoid validation errors.
// This buffer is never used in the shader, since there is a flag
// Bind group entry for the drawCountBuffer is not needed however we need to
// bind something else than nullptr to the bind group entry to avoid validation
// errors. This buffer is never used in the shader, since there is a flag
// (kIndirectDrawCountBuffer) to check if the drawCountBuffer is set.
drawCountBinding.buffer = cmd->indirectBuffer.Get();
drawCountBinding.offset = 0;
Expand All @@ -792,22 +851,22 @@ MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
passEncoder->APISetPipeline(pipeline);
passEncoder->APISetBindGroup(0, bindGroup.Get());

// TODO(crbug.com/356461286): After maxDrawCount has a limit we can
// dispatch exact number of workgroups without worrying about overflow:
// uint32_t workgroupCount = (cmd->maxDrawCount + kWorkgroupSize - 1u) / kWorkgroupSize;
uint32_t workgroupCount = cmd->maxDrawCount / kWorkgroupSize;
// Integer division rounds down so adding 1 if there is a remainder.
workgroupCount += cmd->maxDrawCount % kWorkgroupSize == 0 ? 0 : 1;
passEncoder->APIDispatchWorkgroups(workgroupCount);
passEncoder->APIEnd();

// Update the draw command to use the validated indirect buffer.
// The drawCountBuffer doesn't need to be updated because if it exceeds the maxDrawCount
// it will be clamped to maxDrawCount.
// The drawCountBuffer doesn't need to be updated because if it exceeds the
// maxDrawCount it will be clamped to maxDrawCount.
cmd->indirectBuffer = outputParamsBuffer.GetBuffer();
cmd->indirectOffset = outputOffset;

outputOffset += cmd->maxDrawCount * kDrawIndexedIndirectSize;
// Proceed to the next output offset.
outputOffset += cmd->maxDrawCount *
GetOutputIndirectDrawSize(draw.type, draw.duplicateBaseVertexInstance);
outputOffset = Align(outputOffset, minStorageBufferOffsetAlignment);
}
}

Expand Down
12 changes: 10 additions & 2 deletions src/dawn/native/RenderEncoderBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,11 @@ void RenderEncoderBase::APIMultiDrawIndirect(BufferBase* indirectBuffer,
cmd->drawCountBuffer = drawCountBuffer;
cmd->drawCountOffset = drawCountBufferOffset;

mIndirectDrawMetadata.AddMultiDrawIndirect(cmd);
bool duplicateBaseVertexInstance =
GetDevice()->ShouldDuplicateParametersForDrawIndirect(
mCommandBufferState.GetRenderPipeline());

mIndirectDrawMetadata.AddMultiDrawIndirect(duplicateBaseVertexInstance, cmd);

// TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct usage
// validation, but it will unnecessarily transition to indirectBuffer usage in the
Expand Down Expand Up @@ -470,10 +474,14 @@ void RenderEncoderBase::APIMultiDrawIndexedIndirect(BufferBase* indirectBuffer,
cmd->drawCountBuffer = drawCountBuffer;
cmd->drawCountOffset = drawCountBufferOffset;

bool duplicateBaseVertexInstance =
GetDevice()->ShouldDuplicateParametersForDrawIndirect(
mCommandBufferState.GetRenderPipeline());

mIndirectDrawMetadata.AddMultiDrawIndexedIndirect(
mCommandBufferState.GetIndexBuffer(), mCommandBufferState.GetIndexFormat(),
mCommandBufferState.GetIndexBufferSize(),
mCommandBufferState.GetIndexBufferOffset(), cmd);
mCommandBufferState.GetIndexBufferOffset(), duplicateBaseVertexInstance, cmd);

// TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct usage
// validation, but it will unnecessarily transition to indirectBuffer usage in the
Expand Down
50 changes: 50 additions & 0 deletions src/dawn/native/d3d12/CommandBufferD3D12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,56 @@ MaybeError CommandBuffer::RecordRenderPass(CommandRecordingContext* commandConte
break;
}

case Command::MultiDrawIndirect: {
MultiDrawIndirectCmd* draw = iter->NextCommand<MultiDrawIndirectCmd>();

DAWN_TRY(bindingTracker->Apply(commandContext));
vertexBufferTracker.Apply(commandList, lastPipeline);

Buffer* indirectBuffer = ToBackend(draw->indirectBuffer.Get());
DAWN_ASSERT(indirectBuffer != nullptr);

Buffer* countBuffer = ToBackend(draw->drawCountBuffer.Get());

// There is no distinction between DrawIndirect and MultiDrawIndirect in D3D12.
// This is why we can use the same command signature for both.
ComPtr<ID3D12CommandSignature> signature =
lastPipeline->GetDrawIndirectCommandSignature();

commandList->ExecuteIndirect(
signature.Get(), draw->maxDrawCount, indirectBuffer->GetD3D12Resource(),
draw->indirectOffset,
countBuffer != nullptr ? countBuffer->GetD3D12Resource() : nullptr,
countBuffer != nullptr ? draw->drawCountOffset : 0);

break;
}

case Command::MultiDrawIndexedIndirect: {
MultiDrawIndexedIndirectCmd* draw =
iter->NextCommand<MultiDrawIndexedIndirectCmd>();

DAWN_TRY(bindingTracker->Apply(commandContext));
vertexBufferTracker.Apply(commandList, lastPipeline);

Buffer* indirectBuffer = ToBackend(draw->indirectBuffer.Get());
DAWN_ASSERT(indirectBuffer != nullptr);

Buffer* countBuffer = ToBackend(draw->drawCountBuffer.Get());

// There is no distinction between DrawIndexedIndirect and MultiDrawIndexedIndirect
// in D3D12. This is why we can use the same command signature for both.
ComPtr<ID3D12CommandSignature> signature =
lastPipeline->GetDrawIndexedIndirectCommandSignature();

commandList->ExecuteIndirect(
signature.Get(), draw->maxDrawCount, indirectBuffer->GetD3D12Resource(),
draw->indirectOffset,
countBuffer != nullptr ? countBuffer->GetD3D12Resource() : nullptr,
countBuffer != nullptr ? draw->drawCountOffset : 0);
break;
}

case Command::InsertDebugMarker: {
InsertDebugMarkerCmd* cmd = iter->NextCommand<InsertDebugMarkerCmd>();
const char* label = iter->NextData<char>(cmd->length + 1);
Expand Down
1 change: 1 addition & 0 deletions src/dawn/native/d3d12/PhysicalDeviceD3D12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ void PhysicalDevice::InitializeSupportedFeaturesImpl() {
EnableFeature(Feature::SharedBufferMemoryD3D12Resource);
EnableFeature(Feature::ShaderModuleCompilationOptions);
EnableFeature(Feature::StaticSamplers);
EnableFeature(Feature::MultiDrawIndirect);

if (AreTimestampQueriesSupported()) {
EnableFeature(Feature::TimestampQuery);
Expand Down
Loading

0 comments on commit 3870081

Please sign in to comment.