From 429a59c941988a7ba9f261d1946b9972ee79eca0 Mon Sep 17 00:00:00 2001
From: krohmerNV <42233792+krohmerNV@users.noreply.github.com>
Date: Mon, 20 Mar 2023 02:25:21 +0100
Subject: [PATCH] MDL workaround for structures with material fields (#1274)
The test cases in TestSuite/pbrlib/multioutput/multishaderoutput.mtlx generated invalid structs in MDL. This PR creates a separate function for each struct member, if at least one struct member is a material.
---
.../pbrlib/multioutput/multishaderoutput.mtlx | 31 +++++++
source/MaterialXGenMdl/MdlShaderGenerator.cpp | 10 ++-
.../Nodes/ClosureCompoundNodeMdl.cpp | 68 ++++++++++++++-
.../MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp | 87 +++++++++++++++++--
.../MaterialXGenMdl/Nodes/CompoundNodeMdl.h | 19 ++++
5 files changed, 202 insertions(+), 13 deletions(-)
create mode 100644 resources/Materials/TestSuite/pbrlib/multioutput/multishaderoutput.mtlx
diff --git a/resources/Materials/TestSuite/pbrlib/multioutput/multishaderoutput.mtlx b/resources/Materials/TestSuite/pbrlib/multioutput/multishaderoutput.mtlx
new file mode 100644
index 0000000000..2b1f951e84
--- /dev/null
+++ b/resources/Materials/TestSuite/pbrlib/multioutput/multishaderoutput.mtlx
@@ -0,0 +1,31 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.cpp b/source/MaterialXGenMdl/MdlShaderGenerator.cpp
index aa32682866..0f5b01ce95 100644
--- a/source/MaterialXGenMdl/MdlShaderGenerator.cpp
+++ b/source/MaterialXGenMdl/MdlShaderGenerator.cpp
@@ -382,7 +382,15 @@ string MdlShaderGenerator::getUpstreamResult(const ShaderInput* input, GenContex
const ShaderNode* upstreamNode = upstreamOutput->getNode();
if (upstreamNode->numOutputs() > 1)
{
- variable = upstreamNode->getName() + "_result.mxp_" + upstreamOutput->getName();
+ const CompoundNodeMdl* upstreamNodeMdl = dynamic_cast(&upstreamNode->getImplementation());
+ if (upstreamNodeMdl && upstreamNodeMdl->unrollReturnStructMembers())
+ {
+ variable = upstreamNode->getName() + "__" + upstreamOutput->getName();
+ }
+ else
+ {
+ variable = upstreamNode->getName() + "_result.mxp_" + upstreamOutput->getName();
+ }
}
else
{
diff --git a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
index d5d2a3e533..e175696d55 100644
--- a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
+++ b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
@@ -30,12 +30,74 @@ void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenC
{
const ShaderGenerator& shadergen = context.getShaderGenerator();
- const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) ||
- _rootGraph->hasClassification(ShaderNode::Classification::SHADER));
-
// Emit functions for all child nodes
shadergen.emitFunctionDefinitions(*_rootGraph, context, stage);
+ // split all fields into separate functions
+ if (!_returnStruct.empty() && _unrollReturnStructMembers)
+ {
+ // make sure the upstream definitions are known
+ for (const ShaderGraphOutputSocket* outputSocket : _rootGraph->getOutputSockets())
+ {
+ if (!outputSocket->getConnection())
+ continue;
+
+ const ShaderNode* upstream = outputSocket->getConnection()->getNode();
+ const bool isMaterialExpr = (upstream->hasClassification(ShaderNode::Classification::CLOSURE) ||
+ upstream->hasClassification(ShaderNode::Classification::SHADER));
+
+ // since the emit fuctions are const, the field name to generate a function for is passed via context
+ const std::string& fieldName = outputSocket->getName();
+ GenUserDataStringPtr fieldNamePtr = std::make_shared(fieldName);
+ context.pushUserData(CompoundNodeMdl::GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME, fieldNamePtr);
+
+ // Emit function signature.
+ shadergen.emitComment("unrolled structure field: " + _returnStruct + "." + fieldName + " (name=\"" + node.getName() + "\")", stage);
+ emitFunctionSignature(node, context, stage);
+
+ // Special case for material expresions.
+ if (isMaterialExpr)
+ {
+ shadergen.emitLine(" = let", stage, false);
+ }
+
+ // Function body.
+ shadergen.emitScopeBegin(stage);
+
+ // Emit all texturing nodes. These are inputs to the
+ // closure nodes and need to be emitted first.
+ shadergen.emitFunctionCalls(*_rootGraph, context, stage, ShaderNode::Classification::TEXTURE);
+
+ // Emit function calls for internal closures nodes connected to the graph sockets.
+ // These will in turn emit function calls for any dependent closure nodes upstream.
+ if (upstream->getParent() == _rootGraph.get() &&
+ (upstream->hasClassification(ShaderNode::Classification::CLOSURE) || upstream->hasClassification(ShaderNode::Classification::SHADER)))
+ {
+ shadergen.emitFunctionCall(*upstream, context, stage);
+ }
+
+ // Emit final results
+ if (isMaterialExpr)
+ {
+ shadergen.emitScopeEnd(stage);
+ const string result = shadergen.getUpstreamResult(outputSocket, context);
+ shadergen.emitLine("in material(" + result + ")", stage);
+ }
+ else
+ {
+ const string result = shadergen.getUpstreamResult(outputSocket, context);
+ shadergen.emitLine("return " + result, stage);
+ }
+ shadergen.emitLineBreak(stage);
+
+ context.popUserData(CompoundNodeMdl::GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME);
+ }
+ return;
+ }
+
+ const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) ||
+ _rootGraph->hasClassification(ShaderNode::Classification::SHADER));
+
// Emit function signature.
emitFunctionSignature(node, context, stage);
diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
index df6d61d7f8..03c8bc172f 100644
--- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
+++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
@@ -14,6 +14,8 @@
MATERIALX_NAMESPACE_BEGIN
+const string CompoundNodeMdl::GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME = "returnStructFieldName";
+
ShaderNodeImplPtr CompoundNodeMdl::create()
{
return std::make_shared();
@@ -28,6 +30,15 @@ void CompoundNodeMdl::initialize(const InterfaceElement& element, GenContext& co
{
_returnStruct = _functionName + "__result";
}
+
+ // Materials can not be members of structs. Identify this case in order to handle it.
+ for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
+ {
+ if (output->getType()->getSemantic() == TypeDesc::SEMANTIC_SHADER)
+ {
+ _unrollReturnStructMembers = true;
+ }
+ }
}
void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const
@@ -94,10 +105,48 @@ void CompoundNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& conte
DEFINE_SHADER_STAGE(stage, Stage::PIXEL)
{
const ShaderGenerator& shadergen = context.getShaderGenerator();
+ const Syntax& syntax = shadergen.getSyntax();
// Begin function call.
if (!_returnStruct.empty())
{
+ // when unrolling structure members, the call that creates the struct needs to skipped
+ if (_unrollReturnStructMembers)
+ {
+ // make sure the upstream definitions are known
+ shadergen.emitComment("fill unrolled structure fields: " + _returnStruct + " (name=\"" + node.getName() + "\")", stage);
+ for (const ShaderGraphOutputSocket* outputSocket : _rootGraph->getOutputSockets())
+ {
+ if (!outputSocket->getConnection())
+ continue;
+
+ const std::string& fieldName = outputSocket->getName();
+
+ // Emit the struct field.
+ const string& outputType = syntax.getTypeName(outputSocket->getType());
+ const string resultVariableName = node.getName() + "__" + fieldName;
+
+ shadergen.emitLineBegin(stage);
+ shadergen.emitString(outputType + " " + resultVariableName + " = ", stage);
+ shadergen.emitString(_functionName + "__" + fieldName + "(", stage);
+
+ // Emit inputs.
+ string delim = "";
+ for (ShaderInput* input : node.getInputs())
+ {
+ shadergen.emitString(delim, stage);
+ shadergen.emitInput(input, context, stage);
+ delim = ", ";
+ }
+
+ // End function call
+ shadergen.emitString(")", stage);
+ shadergen.emitLineEnd(stage);
+ }
+
+ return;
+ }
+
// Emit the struct multioutput.
const string resultVariableName = node.getName() + "_result";
shadergen.emitLineBegin(stage);
@@ -135,18 +184,38 @@ void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& conte
if (!_returnStruct.empty())
{
- // Define the output struct.
- shadergen.emitLine("struct " + _returnStruct, stage, false);
- shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS);
- for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
+ if (_unrollReturnStructMembers)
{
- shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage);
+ const auto fieldName = context.getUserData(GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME);
+
+ if (fieldName)
+ {
+ // Begin function signature.
+ const ShaderGraphOutputSocket* outputSocket = _rootGraph->getOutputSocket(fieldName->getValue());
+ const string& outputType = syntax.getTypeName(outputSocket->getType());
+ shadergen.emitLine(outputType + " " + _functionName + "__" + fieldName->getValue(), stage, false);
+ }
+ else
+ {
+ throw Exception("Error during transformation of struct: " + _returnStruct);
+ }
}
- shadergen.emitScopeEnd(stage, true);
- shadergen.emitLineBreak(stage);
+ else
+ {
- // Begin function signature.
- shadergen.emitLine(_returnStruct + " " + _functionName, stage, false);
+ // Define the output struct.
+ shadergen.emitLine("struct " + _returnStruct, stage, false);
+ shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS);
+ for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
+ {
+ shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage);
+ }
+ shadergen.emitScopeEnd(stage, true);
+ shadergen.emitLineBreak(stage);
+
+ // Begin function signature.
+ shadergen.emitLine(_returnStruct + " " + _functionName, stage, false);
+ }
}
else
{
diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h
index 4ac9a5add3..ec33c0a7e3 100644
--- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h
+++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h
@@ -12,6 +12,20 @@
MATERIALX_NAMESPACE_BEGIN
+/// Generator context data class to pass strings.
+class GenUserDataString : public GenUserData
+{
+ public:
+ GenUserDataString(const std::string& value) : _value(value) {}
+ const string& getValue() const { return _value; }
+
+ private:
+ string _value;
+};
+
+/// Shared pointer to a GenUserDataString
+using GenUserDataStringPtr = std::shared_ptr;
+
/// Compound node implementation
class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
{
@@ -22,10 +36,15 @@ class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
+ bool unrollReturnStructMembers() const { return _unrollReturnStructMembers; }
+
protected:
void emitFunctionSignature(const ShaderNode& node, GenContext& context, ShaderStage& stage) const;
string _returnStruct;
+ bool _unrollReturnStructMembers = false;
+
+ static const string GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME;
};
MATERIALX_NAMESPACE_END