diff --git a/lib/DxilContainer/DxcContainerBuilder.cpp b/lib/DxilContainer/DxcContainerBuilder.cpp index 3c10b0e70a..770aa910a4 100644 --- a/lib/DxilContainer/DxcContainerBuilder.cpp +++ b/lib/DxilContainer/DxcContainerBuilder.cpp @@ -104,7 +104,11 @@ HRESULT STDMETHODCALLTYPE DxcContainerBuilder::RemovePart(UINT32 fourCC) { HRESULT STDMETHODCALLTYPE DxcContainerBuilder::SerializeContainer(IDxcOperationResult **ppResult) { + if (ppResult == nullptr) + return E_INVALIDARG; + DxcThreadMalloc TM(m_pMalloc); + try { // Allocate memory for new dxil container. uint32_t ContainerSize = ComputeContainerSize(); @@ -161,6 +165,11 @@ DxcContainerBuilder::SerializeContainer(IDxcOperationResult **ppResult) { errorHeap.Detach(); } + // Add Hash. + if (SUCCEEDED(valHR)) + HashAndUpdate(IsDxilContainerLike(pResult->GetBufferPointer(), + pResult->GetBufferSize())); + IFT(DxcResult::Create( valHR, DXC_OUT_OBJECT, {DxcOutputObject::DataOutput(DXC_OUT_OBJECT, pResult, DxcOutNoName), @@ -169,21 +178,6 @@ DxcContainerBuilder::SerializeContainer(IDxcOperationResult **ppResult) { } CATCH_CPP_RETURN_HRESULT(); - if (ppResult == nullptr || *ppResult == nullptr) - return S_OK; - - HRESULT HR; - (*ppResult)->GetStatus(&HR); - if (FAILED(HR)) - return HR; - - CComPtr pObject; - IFR((*ppResult)->GetResult(&pObject)); - - // Add Hash. - LPVOID PTR = pObject->GetBufferPointer(); - if (IsDxilContainerLike(PTR, pObject->GetBufferSize())) - HashAndUpdate((DxilContainerHeader *)PTR); return S_OK; } diff --git a/tools/clang/tools/dxclib/dxc.cpp b/tools/clang/tools/dxclib/dxc.cpp index 1bcf5d8e3f..cdcfe2b3f6 100644 --- a/tools/clang/tools/dxclib/dxc.cpp +++ b/tools/clang/tools/dxclib/dxc.cpp @@ -644,7 +644,7 @@ int DxcContext::VerifyRootSignature() { IFT(pContainerBuilder->AddPart(hlsl::DxilFourCC::DFCC_RootSignature, pRootSignature)); CComPtr pOperationResult; - pContainerBuilder->SerializeContainer(&pOperationResult); + IFT(pContainerBuilder->SerializeContainer(&pOperationResult)); HRESULT status = E_FAIL; CComPtr pResult; IFT(pOperationResult->GetStatus(&status)); diff --git a/tools/clang/unittests/HLSL/ValidationTest.cpp b/tools/clang/unittests/HLSL/ValidationTest.cpp index 8008541bfa..19696de022 100644 --- a/tools/clang/unittests/HLSL/ValidationTest.cpp +++ b/tools/clang/unittests/HLSL/ValidationTest.cpp @@ -206,6 +206,7 @@ class ValidationTest : public ::testing::Test { TEST_METHOD(SimpleGs1Fail) TEST_METHOD(UavBarrierFail) TEST_METHOD(UndefValueFail) + TEST_METHOD(ValidationFailNoHash) TEST_METHOD(UpdateCounterFail) TEST_METHOD(LocalResCopy) TEST_METHOD(ResCounter) @@ -1189,6 +1190,60 @@ TEST_F(ValidationTest, UavBarrierFail) { TEST_F(ValidationTest, UndefValueFail) { TestCheck(L"..\\CodeGenHLSL\\UndefValue.hlsl"); } +// verify that containers that are not valid DXIL do not +// get assigned a hash. +TEST_F(ValidationTest, ValidationFailNoHash) { + if (m_ver.SkipDxilVersion(1, 8)) + return; + CComPtr pProgram; + + // We need any shader that will pass compilation but fail validation. + // This shader reads from uninitialized 'float a', which works for now. + LPCSTR pSource = R"( + float main(snorm float b : B) : SV_DEPTH + { + float a; + return b + a; + } +)"; + + CComPtr pSourceBlob; + Utf8ToBlob(m_dllSupport, pSource, &pSourceBlob); + std::vector pArguments = {L"-Vd"}; + LPCSTR pShaderModel = "ps_6_0"; + bool result = CompileSource(pSourceBlob, pShaderModel, pArguments.data(), 1, + nullptr, 0, &pProgram); + + VERIFY_IS_TRUE(result); + + CComPtr pValidator; + CComPtr pResult; + unsigned Flags = 0; + VERIFY_SUCCEEDED( + m_dllSupport.CreateInstance(CLSID_DxcValidator, &pValidator)); + + VERIFY_SUCCEEDED(pValidator->Validate(pProgram, Flags, &pResult)); + HRESULT status; + VERIFY_IS_NOT_NULL(pResult); + CComPtr pValidationOutput; + pResult->GetStatus(&status); + + // expect validation to fail + VERIFY_FAILED(status); + pResult->GetResult(&pValidationOutput); + // Make sure the validation output is not null even when validation fails + VERIFY_SUCCEEDED(pValidationOutput != nullptr); + + hlsl::DxilContainerHeader *pHeader = IsDxilContainerLike( + pProgram->GetBufferPointer(), pProgram->GetBufferSize()); + VERIFY_IS_NOT_NULL(pHeader); + + BYTE ZeroHash[DxilContainerHashSize] = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + + // Should be equal, this proves the hash isn't written when validation fails + VERIFY_ARE_EQUAL(memcmp(ZeroHash, pHeader->Hash.Digest, sizeof(ZeroHash)), 0); +} TEST_F(ValidationTest, UpdateCounterFail) { if (m_ver.SkipIRSensitiveTest()) return;