diff --git a/resources/Materials/TestSuite/stdlib/nodegraph_inputs/cascade_nodegraphs.mtlx b/resources/Materials/TestSuite/stdlib/nodegraph_inputs/cascade_nodegraphs.mtlx new file mode 100644 index 0000000000..fada0b4114 --- /dev/null +++ b/resources/Materials/TestSuite/stdlib/nodegraph_inputs/cascade_nodegraphs.mtlx @@ -0,0 +1,69 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/resources/Materials/TestSuite/stdlib/nodegraph_inputs/nodegraph_nodegraph.mtlx b/resources/Materials/TestSuite/stdlib/nodegraph_inputs/nodegraph_nodegraph.mtlx new file mode 100644 index 0000000000..c91b8d8393 --- /dev/null +++ b/resources/Materials/TestSuite/stdlib/nodegraph_inputs/nodegraph_nodegraph.mtlx @@ -0,0 +1,124 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/resources/Materials/TestSuite/stdlib/upgrade/1_38_parameter_to_input.mtlx b/resources/Materials/TestSuite/stdlib/upgrade/1_38_parameter_to_input.mtlx index 18c1925d8b..b44b473fd4 100644 --- a/resources/Materials/TestSuite/stdlib/upgrade/1_38_parameter_to_input.mtlx +++ b/resources/Materials/TestSuite/stdlib/upgrade/1_38_parameter_to_input.mtlx @@ -2,10 +2,10 @@ - + - + @@ -15,8 +15,8 @@ - - + + @@ -33,12 +33,12 @@ - + - + @@ -63,14 +63,14 @@ - + - + diff --git a/source/MaterialXCore/Interface.cpp b/source/MaterialXCore/Interface.cpp index bf0a1c7e2d..ffc84fafe2 100644 --- a/source/MaterialXCore/Interface.cpp +++ b/source/MaterialXCore/Interface.cpp @@ -238,20 +238,33 @@ OutputPtr Input::getConnectedOutput() const // Look for output on a node else if (hasNodeName()) { - ConstGraphElementPtr graph = getAncestorOfType(); - NodePtr node = graph ? graph->getNode(getNodeName()) : nullptr; - if (node) + const string& nodeName = getNodeName(); + ConstElementPtr startingElement = getParent(); + if (startingElement) { - std::vector outputs = node->getOutputs(); - if (!outputs.empty()) + // Look for a node reference above the nodegraph if input is a direct child. + if (startingElement->isA()) { - if (outputString.empty()) - { - result = outputs[0]; - } - else + startingElement = startingElement->getParent(); + } + if (startingElement) + { + ConstGraphElementPtr graph = startingElement->getAncestorOfType(); + NodePtr node = graph ? graph->getNode(nodeName) : nullptr; + if (node) { - result = node->getOutput(outputString); + std::vector outputs = node->getOutputs(); + if (!outputs.empty()) + { + if (outputString.empty()) + { + result = outputs[0]; + } + else + { + result = node->getOutput(outputString); + } + } } } } @@ -264,8 +277,29 @@ OutputPtr Input::getConnectedOutput() const return result; } +InputPtr Input::getConnectedInterface() const +{ + const string& interfaceName = getInterfaceName(); + if (!interfaceName.empty()) + { + ConstNodeGraphPtr graph = getAncestorOfType(); + if (graph) + { + return graph->getInput(interfaceName); + } + } + return nullptr; +} + NodePtr Input::getConnectedNode() const { + // Traverse through interface names to nodegraph input + InputPtr graphInput = getConnectedInterface(); + if (graphInput && (graphInput->hasNodeName() || graphInput->hasNodeGraphString())) + { + return graphInput->getConnectedNode(); + } + OutputPtr output = getConnectedOutput(); if (output) { @@ -275,11 +309,24 @@ NodePtr Input::getConnectedNode() const } if (hasNodeName()) { - ConstGraphElementPtr graph = getAncestorOfType(); - NodePtr node = graph ? graph->getNode(getNodeName()) : nullptr; - if (node) + const string& nodeName = getNodeName(); + ConstElementPtr startingElement = getParent(); + if (startingElement) { - return node; + // Look for a node reference above the nodegraph if input is a direct child. + if (startingElement->isA()) + { + startingElement = startingElement->getParent(); + } + if (startingElement) + { + ConstGraphElementPtr graph = startingElement->getAncestorOfType(); + NodePtr node = graph ? graph->getNode(nodeName) : nullptr; + if (node) + { + return node; + } + } } } return PortElement::getConnectedNode(); @@ -303,6 +350,11 @@ bool Input::validate(string* message) const { validateRequire(getDefaultGeomProp() != nullptr, res, message, "Invalid defaultgeomprop string"); } + InputPtr interfaceInput = getConnectedInterface(); + if (interfaceInput) + { + return interfaceInput->validate() && res; + } return PortElement::validate(message) && res; } diff --git a/source/MaterialXCore/Interface.h b/source/MaterialXCore/Interface.h index 992fc1e0b7..cb966968ed 100644 --- a/source/MaterialXCore/Interface.h +++ b/source/MaterialXCore/Interface.h @@ -218,6 +218,10 @@ class Input : public PortElement /// @name Traversal /// @{ + /// Return the input on the parent graph corresponding to the interface name + /// for the element. + InputPtr getConnectedInterface() const; + /// Return the output, if any, to which this element is connected. OutputPtr getConnectedOutput() const; diff --git a/source/MaterialXCore/Node.cpp b/source/MaterialXCore/Node.cpp index 45e8494481..1f5f5fa373 100644 --- a/source/MaterialXCore/Node.cpp +++ b/source/MaterialXCore/Node.cpp @@ -142,7 +142,20 @@ OutputPtr Node::getNodeDefOutput(ElementPtr connectingElement) OutputPtr output = OutputPtr(); if (connectedInput) { - output = connectedInput->getConnectedOutput(); + InputPtr interfaceInput = nullptr; + if (connectedInput->hasInterfaceName()) + { + interfaceInput = connectedInput->getConnectedInterface(); + if (interfaceInput) + { + outputName = &(interfaceInput->getOutputString()); + output = interfaceInput->getConnectedOutput(); + } + } + if (!interfaceInput) + { + output = connectedInput->getConnectedOutput(); + } } if (output) { @@ -651,6 +664,33 @@ bool NodeGraph::validate(string* message) const validateRequire(getOutputCount() == nodeDef->getActiveOutputs().size(), res, message, "NodeGraph implementation has a different number of outputs than its NodeDef"); } } + // Check interfaces on nodegraphs which are not definitions + if (!hasNodeDefString()) + { + for (NodePtr node : getNodes()) + { + for (InputPtr input : node->getInputs()) + { + const string& interfaceName = input->getInterfaceName(); + if (!interfaceName.empty()) + { + InputPtr interfaceInput = input->getConnectedInterface(); + validateRequire(interfaceInput != nullptr, res, message, "NodeGraph interface input: \"" + interfaceName + "\" does not exist on nodegraph"); + string connectedNodeName = interfaceInput ? interfaceInput->getNodeName() : EMPTY_STRING; + if (connectedNodeName.empty()) + { + connectedNodeName = interfaceInput->getNodeGraphString(); + } + if (interfaceInput && !connectedNodeName.empty()) + { + NodePtr connectedNode = input->getConnectedNode(); + validateRequire(connectedNode != nullptr, res, message, "Nodegraph input: \"" + interfaceInput->getNamePath() + + "\" specifies connection to non existent node: \"" + connectedNodeName + "\""); + } + } + } + } + } return GraphElement::validate(message) && res; } diff --git a/source/MaterialXCore/Traversal.h b/source/MaterialXCore/Traversal.h index 7953d1965c..469ef76781 100644 --- a/source/MaterialXCore/Traversal.h +++ b/source/MaterialXCore/Traversal.h @@ -15,11 +15,9 @@ namespace MaterialX { class Element; -class Material; using ElementPtr = shared_ptr; using ConstElementPtr = shared_ptr; -using ConstMaterialPtr = shared_ptr; /// @class Edge /// An edge between two connected Elements, returned during graph traversal. diff --git a/source/MaterialXRuntime/RtFileIo.cpp b/source/MaterialXRuntime/RtFileIo.cpp index f7e95a52fd..fbf42ca1b1 100644 --- a/source/MaterialXRuntime/RtFileIo.cpp +++ b/source/MaterialXRuntime/RtFileIo.cpp @@ -210,40 +210,53 @@ namespace output->connect(input); } - void createNodeConnections(const vector& nodeElements, PvtPrim* parent, PvtStage* stage, const PvtRenamingMapper& mapper) + void createNodeConnection(InterfaceElementPtr nodeElem, PvtPrim* parent, PvtStage* stage, const PvtRenamingMapper& mapper) { - for (const NodePtr& nodeElem : nodeElements) + PvtPrim* node = findPrimOrThrow(RtToken(nodeElem->getName()), parent, mapper); + for (InputPtr elemInput : nodeElem->getInputs()) { - PvtPrim* node = findPrimOrThrow(RtToken(nodeElem->getName()), parent, mapper); - for (const InputPtr& elemInput : nodeElem->getInputs()) + PvtInput* input = findInputOrThrow(RtToken(elemInput->getName()), node); + string connectedNodeName = elemInput->getNodeName(); + if (connectedNodeName.empty()) { - PvtInput* input = findInputOrThrow(RtToken(elemInput->getName()), node); - string connectedNodeName = elemInput->getNodeName(); - if (connectedNodeName.empty()) - { - connectedNodeName = elemInput->getNodeGraphString(); - } - if (!connectedNodeName.empty()) + connectedNodeName = elemInput->getNodeGraphString(); + } + if (!connectedNodeName.empty()) + { + PvtPrim* connectedNode = findPrimOrThrow(RtToken(connectedNodeName), parent, mapper); + RtToken outputName(elemInput->getOutputString()); + if (outputName == EMPTY_TOKEN && connectedNode) { - PvtPrim* connectedNode = findPrimOrThrow(RtToken(connectedNodeName), parent, mapper); - RtToken outputName(elemInput->getOutputString()); - if (outputName == EMPTY_TOKEN && connectedNode) + RtNode rtConnectedNode(connectedNode->hnd()); + auto output = rtConnectedNode.getOutput(); + if (output) { - RtNode rtConnectedNode(connectedNode->hnd()); - auto output = rtConnectedNode.getOutput(); - if (output) - { - outputName = output.getName(); - } + outputName = output.getName(); } - PvtOutput* output = findOutputOrThrow(outputName, connectedNode); - - createConnection(output, input, elemInput->getChannels(), stage); } + PvtOutput* output = findOutputOrThrow(outputName, connectedNode); + + createConnection(output, input, elemInput->getChannels(), stage); } } } + void createNodeConnections(const vector& nodeElements, PvtPrim* parent, PvtStage* stage, const PvtRenamingMapper& mapper) + { + for (auto nodeElem : nodeElements) + { + createNodeConnection(nodeElem->asA(), parent, stage, mapper); + } + } + + void createNodeGraphConnections(const vector& nodeElements, PvtPrim* parent, PvtStage* stage, const PvtRenamingMapper& mapper) + { + for (auto nodeElem : nodeElements) + { + createNodeConnection(nodeElem->asA(), parent, stage, mapper); + } + } + PvtPrim* readNodeDef(const NodeDefPtr& src, PvtStage* stage) { const RtToken name(src->getName()); @@ -730,6 +743,9 @@ namespace // Create connections between all root level nodes. createNodeConnections(doc->getNodes(), stage->getRootPrim(), stage, mapper); + // Create connections between all nodegraphs + createNodeGraphConnections(doc->getNodeGraphs(), stage->getRootPrim(), stage, mapper); + // Read look information if (!options || options->readLookInformation) { @@ -917,9 +933,16 @@ namespace { // Write connections to upstream nodes. RtOutput source = nodegraphInput.getConnection(); - RtNode sourceNode = source.getParent(); - input->setNodeName(sourceNode.getName()); - if (sourceNode.numOutputs() > 1) + RtPrim sourcePrim = source.getParent(); + if (sourcePrim.hasApi()) + { + input->setNodeGraphString(sourcePrim.getName()); + } + else + { + input->setNodeName(sourcePrim.getName()); + } + if (sourcePrim.numOutputs() > 1) { input->setOutputString(source.getName()); } diff --git a/source/MaterialXTest/MaterialXCore/Traversal.cpp b/source/MaterialXTest/MaterialXCore/Traversal.cpp index 38e22068f8..2b7f473a4d 100644 --- a/source/MaterialXTest/MaterialXCore/Traversal.cpp +++ b/source/MaterialXTest/MaterialXCore/Traversal.cpp @@ -6,10 +6,12 @@ #include #include +#include +#include namespace mx = MaterialX; -TEST_CASE("Traversal", "[traversal]") +TEST_CASE("IntraGraph Traversal", "[traversal]") { // Test null iterators. mx::TreeIterator nullTree = mx::NULL_TREE_ITERATOR; @@ -171,3 +173,28 @@ TEST_CASE("Traversal", "[traversal]") REQUIRE(!output->hasUpstreamCycle()); REQUIRE(doc->validate()); } + +TEST_CASE("InterGraph Tranversal", "[traversal]") +{ + mx::FileSearchPath searchPath; + const mx::FilePath currentPath = mx::FilePath::getCurrentPath(); + searchPath.append(currentPath / mx::FilePath("libraries")); + + mx::DocumentPtr doc = mx::createDocument(); + mx::loadLibraries({ "stdlib", "pbrlib", "bxdf" }, searchPath, doc); + + mx::FilePath testFile = mx::FilePath::getCurrentPath() / mx::FilePath("resources/Materials/TestSuite/stdlib/nodegraph_inputs/nodegraph_nodegraph.mtlx"); + mx::readFromXmlFile(doc, testFile, searchPath); + REQUIRE(doc->validate()); + + for (mx::NodeGraphPtr graph : doc->getNodeGraphs()) + { + for (mx::InputPtr interfaceInput : graph->getInputs()) + { + if (!interfaceInput->getNodeName().empty() || !interfaceInput->getNodeGraphString().empty()) + { + REQUIRE(interfaceInput->getConnectedNode() != nullptr); + } + } + } +}