Skip to content

Commit

Permalink
Data Library API
Browse files Browse the repository at this point in the history
Introduce methods to register a data library for a document
  • Loading branch information
ashwinbhat committed Oct 8, 2024
1 parent 19d6928 commit 7e61f63
Show file tree
Hide file tree
Showing 14 changed files with 125 additions and 20 deletions.
4 changes: 3 additions & 1 deletion source/MaterialXCore/Definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,10 @@ StringVec TargetDef::getMatchingTargets() const

vector<UnitDefPtr> UnitTypeDef::getUnitDefs() const
{
const auto datalibrary = getDocument()->hasDataLibrary() ? getDocument()->getRegisteredDataLibrary() : getDocument();

vector<UnitDefPtr> unitDefs;
for (UnitDefPtr unitDef : getDocument()->getChildrenOfType<UnitDef>())
for (UnitDefPtr unitDef : datalibrary->getChildrenOfType<UnitDef>())
{
if (unitDef->getUnitType() == _name)
{
Expand Down
17 changes: 17 additions & 0 deletions source/MaterialXCore/Document.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,14 @@ vector<OutputPtr> Document::getMaterialOutputs() const

vector<NodeDefPtr> Document::getMatchingNodeDefs(const string& nodeName) const
{
// Return all nodedefs from datalibrary if available
if (_dataLibrary)
{
auto datalibrarynodes = _dataLibrary->getMatchingNodeDefs(nodeName);
if (!datalibrarynodes.empty())
return datalibrarynodes;
}

// Refresh the cache.
_cache->refresh();

Expand All @@ -373,6 +381,15 @@ vector<NodeDefPtr> Document::getMatchingNodeDefs(const string& nodeName) const

vector<InterfaceElementPtr> Document::getMatchingImplementations(const string& nodeDef) const
{

// Return all implementations from datalibrary if available
if (_dataLibrary)
{
auto datalibrarynodes = _dataLibrary->getMatchingImplementations(nodeDef);
if (!datalibrarynodes.empty())
return datalibrarynodes;
}

// Refresh the cache.
_cache->refresh();

Expand Down
48 changes: 48 additions & 0 deletions source/MaterialXCore/Document.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,16 @@ class MX_CORE_API Document : public GraphElement
/// Return the NodeGraph, if any, with the given name.
NodeGraphPtr getNodeGraph(const string& name) const
{
if (_dataLibrary)
return _dataLibrary->getChildOfType<NodeGraph>(name);
return getChildOfType<NodeGraph>(name);
}

/// Return a vector of all NodeGraph elements in the document.
vector<NodeGraphPtr> getNodeGraphs() const
{
if (_dataLibrary)
return _dataLibrary->getChildrenOfType<NodeGraph>();
return getChildrenOfType<NodeGraph>();
}

Expand Down Expand Up @@ -345,12 +349,16 @@ class MX_CORE_API Document : public GraphElement
/// Return the NodeDef, if any, with the given name.
NodeDefPtr getNodeDef(const string& name) const
{
if (_dataLibrary)
return _dataLibrary->getChildOfType<NodeDef>(name);
return getChildOfType<NodeDef>(name);
}

/// Return a vector of all NodeDef elements in the document.
vector<NodeDefPtr> getNodeDefs() const
{
if (_dataLibrary)
return _dataLibrary->getChildrenOfType<NodeDef>();
return getChildrenOfType<NodeDef>();
}

Expand Down Expand Up @@ -380,12 +388,16 @@ class MX_CORE_API Document : public GraphElement
/// Return the AttributeDef, if any, with the given name.
AttributeDefPtr getAttributeDef(const string& name) const
{
if (_dataLibrary)
return _dataLibrary->getChildOfType<AttributeDef>(name);
return getChildOfType<AttributeDef>(name);
}

/// Return a vector of all AttributeDef elements in the document.
vector<AttributeDefPtr> getAttributeDefs() const
{
if (_dataLibrary)
return _dataLibrary->getChildrenOfType<AttributeDef>();
return getChildrenOfType<AttributeDef>();
}

Expand All @@ -412,12 +424,16 @@ class MX_CORE_API Document : public GraphElement
/// Return the AttributeDef, if any, with the given name.
TargetDefPtr getTargetDef(const string& name) const
{
if (_dataLibrary)
return _dataLibrary->getChildOfType<TargetDef>(name);
return getChildOfType<TargetDef>(name);
}

/// Return a vector of all TargetDef elements in the document.
vector<TargetDefPtr> getTargetDefs() const
{
if (_dataLibrary)
return _dataLibrary->getChildrenOfType<TargetDef>();
return getChildrenOfType<TargetDef>();
}

Expand Down Expand Up @@ -508,12 +524,16 @@ class MX_CORE_API Document : public GraphElement
/// Return the Implementation, if any, with the given name.
ImplementationPtr getImplementation(const string& name) const
{
if (_dataLibrary)
return _dataLibrary->getChildOfType<Implementation>(name);
return getChildOfType<Implementation>(name);
}

/// Return a vector of all Implementation elements in the document.
vector<ImplementationPtr> getImplementations() const
{
if (_dataLibrary)
return _dataLibrary->getChildrenOfType<Implementation>();
return getChildrenOfType<Implementation>();
}

Expand Down Expand Up @@ -665,6 +685,32 @@ class MX_CORE_API Document : public GraphElement

/// @}

/// @name MaterialX data library
/// @{

/// Register the given document as MaterialX data library for document
/// The MaterialX data library can be created using the loadLibraries utility
/// For improved performance it is recommended the data library
/// is on the document instead of importing it
/// @param Data Library document to register.
void registerDataLibrary(ConstDocumentPtr dataLibrary)
{
_dataLibrary = dataLibrary;
}

/// Gets the registered data library
ConstDocumentPtr getRegisteredDataLibrary() const
{
return _dataLibrary;
}

/// Returns true if a data library is registered.
bool hasDataLibrary() const
{
return (_dataLibrary != nullptr);
}
/// @}

//
// These are deprecated wrappers for older versions of the function interfaces in this module.
// Clients using these interfaces should update them to the latest API.
Expand All @@ -680,6 +726,8 @@ class MX_CORE_API Document : public GraphElement
private:
class Cache;
std::unique_ptr<Cache> _cache;
// Data library for the document
ConstDocumentPtr _dataLibrary;
};

/// Create a new Document.
Expand Down
5 changes: 4 additions & 1 deletion source/MaterialXCore/Element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,10 @@ bool ValueElement::validate(string* message) const
const string& unittype = getUnitType();
if (!unittype.empty())
{
unitTypeDef = getDocument()->getUnitTypeDef(unittype);

unitTypeDef = getDocument()->hasDataLibrary() ?
getDocument()->getRegisteredDataLibrary()->getUnitTypeDef(unittype) :
getDocument()->getUnitTypeDef(unittype);
validateRequire(unitTypeDef != nullptr, res, message, "Unit type definition does not exist in document");
}
}
Expand Down
4 changes: 3 additions & 1 deletion source/MaterialXCore/Interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,9 @@ GeomPropDefPtr Input::getDefaultGeomProp() const
const string& defaultGeomProp = getAttribute(DEFAULT_GEOM_PROP_ATTRIBUTE);
if (!defaultGeomProp.empty())
{
ConstDocumentPtr doc = getDocument();
ConstDocumentPtr doc = getDocument()->hasDataLibrary() ?
getDocument()->getRegisteredDataLibrary() :
getDocument();
return doc->getChildOfType<GeomPropDef>(defaultGeomProp);
}
return nullptr;
Expand Down
29 changes: 24 additions & 5 deletions source/MaterialXCore/Node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,21 @@ string Node::getConnectedNodeName(const string& inputName) const

NodeDefPtr Node::getNodeDef(const string& target, bool allowRoughMatch) const
{
if (hasNodeDefString())
{
return resolveNameReference<NodeDef>(getNodeDefString());
}
vector<NodeDefPtr> nodeDefs = getDocument()->getMatchingNodeDefs(getQualifiedName(getCategory()));
vector<NodeDefPtr> secondary = getDocument()->getMatchingNodeDefs(getCategory());
vector<NodeDefPtr> roughMatches;
nodeDefs.insert(nodeDefs.end(), secondary.begin(), secondary.end());

// Search data library if available
if (getDocument()->hasDataLibrary())
{
vector<NodeDefPtr> libraryNodeDefs = getDocument()->getRegisteredDataLibrary()->getMatchingNodeDefs(getQualifiedName(getCategory()));
vector<NodeDefPtr> librarySecondardNodeDefs = getDocument()->getRegisteredDataLibrary()->getMatchingNodeDefs(getCategory());
nodeDefs.insert(nodeDefs.end(), libraryNodeDefs.begin(), libraryNodeDefs.end());
nodeDefs.insert(nodeDefs.end(), librarySecondardNodeDefs.begin(), librarySecondardNodeDefs.end());
}

vector<NodeDefPtr> roughMatches;

for (NodeDefPtr nodeDef : nodeDefs)
{
if (!targetStringsMatch(nodeDef->getTarget(), target) ||
Expand Down Expand Up @@ -714,6 +721,18 @@ NodeDefPtr NodeGraph::getNodeDef() const
}
}
}
// Check datalibrary if available
if (!nodedef && getDocument()->hasDataLibrary())
{
const auto datalibray = getDocument()->getRegisteredDataLibrary();
for (auto impl : datalibray->getImplementations())
{
if (impl->getNodeGraph() == getQualifiedName(getName()))
{
nodedef = impl->getNodeDef();
}
}
}
return nodedef;
}

Expand Down
11 changes: 8 additions & 3 deletions source/MaterialXGenShader/ShaderGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,15 @@ void ShaderGraph::addDefaultGeomNode(ShaderInput* input, const GeomPropDef& geom
// input here and ignore the type of the geomprop. They are required to have the same type.
string geomNodeDefName = "ND_" + geomprop.getGeomProp() + "_" + input->getType().getName();
NodeDefPtr geomNodeDef = _document->getNodeDef(geomNodeDefName);
if (!geomNodeDef)
if (!geomNodeDef && _document->hasDataLibrary())
{
throw ExceptionShaderGenError("Could not find a nodedef named '" + geomNodeDefName +
"' for defaultgeomprop on input '" + input->getFullName() + "'");
geomNodeDef = _document->getRegisteredDataLibrary()->getNodeDef(geomNodeDefName);
if (!geomNodeDef)
{

throw ExceptionShaderGenError("Could not find a nodedef named '" + geomNodeDefName +
"' for defaultgeomprop on input '" + input->getFullName() + "'");
}
}

ShaderNodePtr geomNode = ShaderNode::create(this, geomNodeName, *geomNodeDef, context);
Expand Down
6 changes: 5 additions & 1 deletion source/MaterialXTest/MaterialXCore/Document.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ TEST_CASE("Document", "[document]")
REQUIRE(customLibrary->validate());

// Import the custom library.
doc->importLibrary(customLibrary);
mx::DocumentPtr customdatalibrary = mx::createDocument();
customdatalibrary->importLibrary(customLibrary);

// Register data library
doc->registerDataLibrary(customdatalibrary);
mx::NodeGraphPtr importedNodeGraph = doc->getNodeGraph("custom:NG_custom");
mx::NodeDefPtr importedNodeDef = doc->getNodeDef("custom:ND_simpleSrf");
mx::ImplementationPtr importedImpl = doc->getImplementation("custom:IM_custom");
Expand Down
2 changes: 1 addition & 1 deletion source/MaterialXTest/MaterialXCore/Node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ TEST_CASE("Node Definition Creation", "[nodedef]")

mx::DocumentPtr doc = mx::createDocument();
mx::readFromXmlFile(doc, "resources/Materials/TestSuite/stdlib/definition/definition_from_nodegraph.mtlx", searchPath);
doc->importLibrary(stdlib);
doc->registerDataLibrary(stdlib);

mx::NodeGraphPtr graph = doc->getNodeGraph("test_colorcorrect");
REQUIRE(graph);
Expand Down
8 changes: 4 additions & 4 deletions source/MaterialXTest/MaterialXGenShader/GenShader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ TEST_CASE("GenShader: Transparency Regression Check", "[genshader]")
bool testValue = transparencyTest[i];

mx::DocumentPtr testDoc = mx::createDocument();
testDoc->importLibrary(libraries);
testDoc->registerDataLibrary(libraries);

try
{
Expand Down Expand Up @@ -207,7 +207,7 @@ void testDeterministicGeneration(mx::DocumentPtr libraries, mx::GenContext& cont
{
mx::DocumentPtr testDoc = mx::createDocument();
mx::readFromXmlFile(testDoc, testFile);
testDoc->importLibrary(libraries);
testDoc->registerDataLibrary(libraries);

// Keep the document alive to make sure
// new memory is allocated for each run
Expand Down Expand Up @@ -272,7 +272,7 @@ void checkPixelDependencies(mx::DocumentPtr libraries, mx::GenContext& context)

mx::DocumentPtr testDoc = mx::createDocument();
mx::readFromXmlFile(testDoc, testFile);
testDoc->importLibrary(libraries);
testDoc->registerDataLibrary(libraries);

mx::ElementPtr element = testDoc->getChild(testElement);
CHECK(element);
Expand Down Expand Up @@ -385,7 +385,7 @@ TEST_CASE("GenShader: Track Application Variables", "[genshader]")

mx::DocumentPtr testDoc = mx::createDocument();
mx::readFromXmlString(testDoc, testDocumentString);
testDoc->importLibrary(libraries);
testDoc->registerDataLibrary(libraries);

mx::ElementPtr element = testDoc->getChild(testElement);
CHECK(element);
Expand Down
4 changes: 3 additions & 1 deletion source/MaterialXTest/MaterialXGenShader/GenShaderUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ void shaderGenPerformanceTest(mx::GenContext& context)
std::shuffle(loadedDocuments.begin(), loadedDocuments.end(), rng);
for (const auto& doc : loadedDocuments)
{
doc->importLibrary(nodeLibrary);
doc->registerDataLibrary(nodeLibrary);
std::vector<mx::TypedElementPtr> elements = mx::findRenderableElements(doc);

REQUIRE(elements.size() > 0);
Expand Down Expand Up @@ -721,6 +721,8 @@ void ShaderGeneratorTester::validate(const mx::GenOptions& generateOptions, cons
bool importedLibrary = false;
try
{
//TODO: Enable setDataLibrary and ensures all implementations are accounted for.
// doc->setDataLibrary(_dependLib);
doc->importLibrary(_dependLib);
importedLibrary = true;
}
Expand Down
2 changes: 1 addition & 1 deletion source/MaterialXTest/MaterialXRender/RenderUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ bool ShaderRenderTester::validate(const mx::FilePath optionsFilePath)
// colliding with implementations in previous test cases.
context.clearNodeImplementations();

doc->importLibrary(dependLib);
doc->registerDataLibrary(dependLib);
ioTimer.endTimer();

validateTimer.startTimer();
Expand Down
2 changes: 1 addition & 1 deletion source/MaterialXView/Viewer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1317,7 +1317,7 @@ void Viewer::loadDocument(const mx::FilePath& filename, mx::DocumentPtr librarie
_materialSearchPath = mx::getSourceSearchPath(doc);

// Import libraries.
doc->importLibrary(libraries);
doc->registerDataLibrary(libraries);

// Apply direct lights.
applyDirectLights(doc);
Expand Down
3 changes: 3 additions & 0 deletions source/PyMaterialX/PyMaterialXCore/PyDocument.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ void bindPyDocument(py::module& mod)
.def("initialize", &mx::Document::initialize)
.def("copy", &mx::Document::copy)
.def("importLibrary", &mx::Document::importLibrary)
.def("setDataLibrary", &mx::Document::registerDataLibrary)
.def("getDataLibrary", &mx::Document::getRegisteredDataLibrary)
.def("hasDataLibrary", &mx::Document::hasDataLibrary)
.def("getReferencedSourceUris", &mx::Document::getReferencedSourceUris)
.def("addNodeGraph", &mx::Document::addNodeGraph,
py::arg("name") = mx::EMPTY_STRING)
Expand Down

0 comments on commit 7e61f63

Please sign in to comment.