diff --git a/resources/Materials/TestSuite/stdlib/nodegraph_inputs/cascade_nodegraphs.mtlx b/resources/Materials/TestSuite/stdlib/nodegraph_inputs/cascade_nodegraphs.mtlx
new file mode 100644
index 0000000000..fada0b4114
--- /dev/null
+++ b/resources/Materials/TestSuite/stdlib/nodegraph_inputs/cascade_nodegraphs.mtlx
@@ -0,0 +1,69 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/resources/Materials/TestSuite/stdlib/nodegraph_inputs/nodegraph_nodegraph.mtlx b/resources/Materials/TestSuite/stdlib/nodegraph_inputs/nodegraph_nodegraph.mtlx
new file mode 100644
index 0000000000..c91b8d8393
--- /dev/null
+++ b/resources/Materials/TestSuite/stdlib/nodegraph_inputs/nodegraph_nodegraph.mtlx
@@ -0,0 +1,124 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/resources/Materials/TestSuite/stdlib/upgrade/1_38_parameter_to_input.mtlx b/resources/Materials/TestSuite/stdlib/upgrade/1_38_parameter_to_input.mtlx
index 18c1925d8b..b44b473fd4 100644
--- a/resources/Materials/TestSuite/stdlib/upgrade/1_38_parameter_to_input.mtlx
+++ b/resources/Materials/TestSuite/stdlib/upgrade/1_38_parameter_to_input.mtlx
@@ -2,10 +2,10 @@
-
+
-
+
@@ -15,8 +15,8 @@
-
-
+
+
@@ -33,12 +33,12 @@
-
+
-
+
@@ -63,14 +63,14 @@
-
+
-
+
diff --git a/source/MaterialXCore/Interface.cpp b/source/MaterialXCore/Interface.cpp
index bf0a1c7e2d..ffc84fafe2 100644
--- a/source/MaterialXCore/Interface.cpp
+++ b/source/MaterialXCore/Interface.cpp
@@ -238,20 +238,33 @@ OutputPtr Input::getConnectedOutput() const
// Look for output on a node
else if (hasNodeName())
{
- ConstGraphElementPtr graph = getAncestorOfType();
- NodePtr node = graph ? graph->getNode(getNodeName()) : nullptr;
- if (node)
+ const string& nodeName = getNodeName();
+ ConstElementPtr startingElement = getParent();
+ if (startingElement)
{
- std::vector outputs = node->getOutputs();
- if (!outputs.empty())
+ // Look for a node reference above the nodegraph if input is a direct child.
+ if (startingElement->isA())
{
- if (outputString.empty())
- {
- result = outputs[0];
- }
- else
+ startingElement = startingElement->getParent();
+ }
+ if (startingElement)
+ {
+ ConstGraphElementPtr graph = startingElement->getAncestorOfType();
+ NodePtr node = graph ? graph->getNode(nodeName) : nullptr;
+ if (node)
{
- result = node->getOutput(outputString);
+ std::vector outputs = node->getOutputs();
+ if (!outputs.empty())
+ {
+ if (outputString.empty())
+ {
+ result = outputs[0];
+ }
+ else
+ {
+ result = node->getOutput(outputString);
+ }
+ }
}
}
}
@@ -264,8 +277,29 @@ OutputPtr Input::getConnectedOutput() const
return result;
}
+InputPtr Input::getConnectedInterface() const
+{
+ const string& interfaceName = getInterfaceName();
+ if (!interfaceName.empty())
+ {
+ ConstNodeGraphPtr graph = getAncestorOfType();
+ if (graph)
+ {
+ return graph->getInput(interfaceName);
+ }
+ }
+ return nullptr;
+}
+
NodePtr Input::getConnectedNode() const
{
+ // Traverse through interface names to nodegraph input
+ InputPtr graphInput = getConnectedInterface();
+ if (graphInput && (graphInput->hasNodeName() || graphInput->hasNodeGraphString()))
+ {
+ return graphInput->getConnectedNode();
+ }
+
OutputPtr output = getConnectedOutput();
if (output)
{
@@ -275,11 +309,24 @@ NodePtr Input::getConnectedNode() const
}
if (hasNodeName())
{
- ConstGraphElementPtr graph = getAncestorOfType();
- NodePtr node = graph ? graph->getNode(getNodeName()) : nullptr;
- if (node)
+ const string& nodeName = getNodeName();
+ ConstElementPtr startingElement = getParent();
+ if (startingElement)
{
- return node;
+ // Look for a node reference above the nodegraph if input is a direct child.
+ if (startingElement->isA())
+ {
+ startingElement = startingElement->getParent();
+ }
+ if (startingElement)
+ {
+ ConstGraphElementPtr graph = startingElement->getAncestorOfType();
+ NodePtr node = graph ? graph->getNode(nodeName) : nullptr;
+ if (node)
+ {
+ return node;
+ }
+ }
}
}
return PortElement::getConnectedNode();
@@ -303,6 +350,11 @@ bool Input::validate(string* message) const
{
validateRequire(getDefaultGeomProp() != nullptr, res, message, "Invalid defaultgeomprop string");
}
+ InputPtr interfaceInput = getConnectedInterface();
+ if (interfaceInput)
+ {
+ return interfaceInput->validate() && res;
+ }
return PortElement::validate(message) && res;
}
diff --git a/source/MaterialXCore/Interface.h b/source/MaterialXCore/Interface.h
index 992fc1e0b7..cb966968ed 100644
--- a/source/MaterialXCore/Interface.h
+++ b/source/MaterialXCore/Interface.h
@@ -218,6 +218,10 @@ class Input : public PortElement
/// @name Traversal
/// @{
+ /// Return the input on the parent graph corresponding to the interface name
+ /// for the element.
+ InputPtr getConnectedInterface() const;
+
/// Return the output, if any, to which this element is connected.
OutputPtr getConnectedOutput() const;
diff --git a/source/MaterialXCore/Node.cpp b/source/MaterialXCore/Node.cpp
index 45e8494481..1f5f5fa373 100644
--- a/source/MaterialXCore/Node.cpp
+++ b/source/MaterialXCore/Node.cpp
@@ -142,7 +142,20 @@ OutputPtr Node::getNodeDefOutput(ElementPtr connectingElement)
OutputPtr output = OutputPtr();
if (connectedInput)
{
- output = connectedInput->getConnectedOutput();
+ InputPtr interfaceInput = nullptr;
+ if (connectedInput->hasInterfaceName())
+ {
+ interfaceInput = connectedInput->getConnectedInterface();
+ if (interfaceInput)
+ {
+ outputName = &(interfaceInput->getOutputString());
+ output = interfaceInput->getConnectedOutput();
+ }
+ }
+ if (!interfaceInput)
+ {
+ output = connectedInput->getConnectedOutput();
+ }
}
if (output)
{
@@ -651,6 +664,33 @@ bool NodeGraph::validate(string* message) const
validateRequire(getOutputCount() == nodeDef->getActiveOutputs().size(), res, message, "NodeGraph implementation has a different number of outputs than its NodeDef");
}
}
+ // Check interfaces on nodegraphs which are not definitions
+ if (!hasNodeDefString())
+ {
+ for (NodePtr node : getNodes())
+ {
+ for (InputPtr input : node->getInputs())
+ {
+ const string& interfaceName = input->getInterfaceName();
+ if (!interfaceName.empty())
+ {
+ InputPtr interfaceInput = input->getConnectedInterface();
+ validateRequire(interfaceInput != nullptr, res, message, "NodeGraph interface input: \"" + interfaceName + "\" does not exist on nodegraph");
+ string connectedNodeName = interfaceInput ? interfaceInput->getNodeName() : EMPTY_STRING;
+ if (connectedNodeName.empty())
+ {
+ connectedNodeName = interfaceInput->getNodeGraphString();
+ }
+ if (interfaceInput && !connectedNodeName.empty())
+ {
+ NodePtr connectedNode = input->getConnectedNode();
+ validateRequire(connectedNode != nullptr, res, message, "Nodegraph input: \"" + interfaceInput->getNamePath() +
+ "\" specifies connection to non existent node: \"" + connectedNodeName + "\"");
+ }
+ }
+ }
+ }
+ }
return GraphElement::validate(message) && res;
}
diff --git a/source/MaterialXCore/Traversal.h b/source/MaterialXCore/Traversal.h
index 7953d1965c..469ef76781 100644
--- a/source/MaterialXCore/Traversal.h
+++ b/source/MaterialXCore/Traversal.h
@@ -15,11 +15,9 @@ namespace MaterialX
{
class Element;
-class Material;
using ElementPtr = shared_ptr;
using ConstElementPtr = shared_ptr;
-using ConstMaterialPtr = shared_ptr;
/// @class Edge
/// An edge between two connected Elements, returned during graph traversal.
diff --git a/source/MaterialXRuntime/RtFileIo.cpp b/source/MaterialXRuntime/RtFileIo.cpp
index f7e95a52fd..fbf42ca1b1 100644
--- a/source/MaterialXRuntime/RtFileIo.cpp
+++ b/source/MaterialXRuntime/RtFileIo.cpp
@@ -210,40 +210,53 @@ namespace
output->connect(input);
}
- void createNodeConnections(const vector& nodeElements, PvtPrim* parent, PvtStage* stage, const PvtRenamingMapper& mapper)
+ void createNodeConnection(InterfaceElementPtr nodeElem, PvtPrim* parent, PvtStage* stage, const PvtRenamingMapper& mapper)
{
- for (const NodePtr& nodeElem : nodeElements)
+ PvtPrim* node = findPrimOrThrow(RtToken(nodeElem->getName()), parent, mapper);
+ for (InputPtr elemInput : nodeElem->getInputs())
{
- PvtPrim* node = findPrimOrThrow(RtToken(nodeElem->getName()), parent, mapper);
- for (const InputPtr& elemInput : nodeElem->getInputs())
+ PvtInput* input = findInputOrThrow(RtToken(elemInput->getName()), node);
+ string connectedNodeName = elemInput->getNodeName();
+ if (connectedNodeName.empty())
{
- PvtInput* input = findInputOrThrow(RtToken(elemInput->getName()), node);
- string connectedNodeName = elemInput->getNodeName();
- if (connectedNodeName.empty())
- {
- connectedNodeName = elemInput->getNodeGraphString();
- }
- if (!connectedNodeName.empty())
+ connectedNodeName = elemInput->getNodeGraphString();
+ }
+ if (!connectedNodeName.empty())
+ {
+ PvtPrim* connectedNode = findPrimOrThrow(RtToken(connectedNodeName), parent, mapper);
+ RtToken outputName(elemInput->getOutputString());
+ if (outputName == EMPTY_TOKEN && connectedNode)
{
- PvtPrim* connectedNode = findPrimOrThrow(RtToken(connectedNodeName), parent, mapper);
- RtToken outputName(elemInput->getOutputString());
- if (outputName == EMPTY_TOKEN && connectedNode)
+ RtNode rtConnectedNode(connectedNode->hnd());
+ auto output = rtConnectedNode.getOutput();
+ if (output)
{
- RtNode rtConnectedNode(connectedNode->hnd());
- auto output = rtConnectedNode.getOutput();
- if (output)
- {
- outputName = output.getName();
- }
+ outputName = output.getName();
}
- PvtOutput* output = findOutputOrThrow(outputName, connectedNode);
-
- createConnection(output, input, elemInput->getChannels(), stage);
}
+ PvtOutput* output = findOutputOrThrow(outputName, connectedNode);
+
+ createConnection(output, input, elemInput->getChannels(), stage);
}
}
}
+ void createNodeConnections(const vector& nodeElements, PvtPrim* parent, PvtStage* stage, const PvtRenamingMapper& mapper)
+ {
+ for (auto nodeElem : nodeElements)
+ {
+ createNodeConnection(nodeElem->asA(), parent, stage, mapper);
+ }
+ }
+
+ void createNodeGraphConnections(const vector& nodeElements, PvtPrim* parent, PvtStage* stage, const PvtRenamingMapper& mapper)
+ {
+ for (auto nodeElem : nodeElements)
+ {
+ createNodeConnection(nodeElem->asA(), parent, stage, mapper);
+ }
+ }
+
PvtPrim* readNodeDef(const NodeDefPtr& src, PvtStage* stage)
{
const RtToken name(src->getName());
@@ -730,6 +743,9 @@ namespace
// Create connections between all root level nodes.
createNodeConnections(doc->getNodes(), stage->getRootPrim(), stage, mapper);
+ // Create connections between all nodegraphs
+ createNodeGraphConnections(doc->getNodeGraphs(), stage->getRootPrim(), stage, mapper);
+
// Read look information
if (!options || options->readLookInformation)
{
@@ -917,9 +933,16 @@ namespace
{
// Write connections to upstream nodes.
RtOutput source = nodegraphInput.getConnection();
- RtNode sourceNode = source.getParent();
- input->setNodeName(sourceNode.getName());
- if (sourceNode.numOutputs() > 1)
+ RtPrim sourcePrim = source.getParent();
+ if (sourcePrim.hasApi())
+ {
+ input->setNodeGraphString(sourcePrim.getName());
+ }
+ else
+ {
+ input->setNodeName(sourcePrim.getName());
+ }
+ if (sourcePrim.numOutputs() > 1)
{
input->setOutputString(source.getName());
}
diff --git a/source/MaterialXTest/MaterialXCore/Traversal.cpp b/source/MaterialXTest/MaterialXCore/Traversal.cpp
index 38e22068f8..2b7f473a4d 100644
--- a/source/MaterialXTest/MaterialXCore/Traversal.cpp
+++ b/source/MaterialXTest/MaterialXCore/Traversal.cpp
@@ -6,10 +6,12 @@
#include
#include
+#include
+#include
namespace mx = MaterialX;
-TEST_CASE("Traversal", "[traversal]")
+TEST_CASE("IntraGraph Traversal", "[traversal]")
{
// Test null iterators.
mx::TreeIterator nullTree = mx::NULL_TREE_ITERATOR;
@@ -171,3 +173,28 @@ TEST_CASE("Traversal", "[traversal]")
REQUIRE(!output->hasUpstreamCycle());
REQUIRE(doc->validate());
}
+
+TEST_CASE("InterGraph Tranversal", "[traversal]")
+{
+ mx::FileSearchPath searchPath;
+ const mx::FilePath currentPath = mx::FilePath::getCurrentPath();
+ searchPath.append(currentPath / mx::FilePath("libraries"));
+
+ mx::DocumentPtr doc = mx::createDocument();
+ mx::loadLibraries({ "stdlib", "pbrlib", "bxdf" }, searchPath, doc);
+
+ mx::FilePath testFile = mx::FilePath::getCurrentPath() / mx::FilePath("resources/Materials/TestSuite/stdlib/nodegraph_inputs/nodegraph_nodegraph.mtlx");
+ mx::readFromXmlFile(doc, testFile, searchPath);
+ REQUIRE(doc->validate());
+
+ for (mx::NodeGraphPtr graph : doc->getNodeGraphs())
+ {
+ for (mx::InputPtr interfaceInput : graph->getInputs())
+ {
+ if (!interfaceInput->getNodeName().empty() || !interfaceInput->getNodeGraphString().empty())
+ {
+ REQUIRE(interfaceInput->getConnectedNode() != nullptr);
+ }
+ }
+ }
+}