Skip to content

Commit

Permalink
Generator: take into account tracked/untracked variables when generat…
Browse files Browse the repository at this point in the history
…ing some code.
  • Loading branch information
agarny committed Oct 27, 2024
1 parent 723d025 commit b2748c5
Show file tree
Hide file tree
Showing 22 changed files with 1,810 additions and 36 deletions.
5 changes: 4 additions & 1 deletion src/api/libcellml/analyser.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class LIBCELLML_EXPORT Analyser: public Logger
* @brief Add a @ref VariablePtr as an external variable to this @ref Analyser.
*
* Add the given @ref VariablePtr as an external variable to this @ref Analyser, but only if it has not already been
* added.
* added. Please note that it is your responsibility to ensure that all the variables on which an external variable
* depends are tracked.
*
* @param variable The @ref Variable to add as an external variable.
*
Expand All @@ -76,6 +77,8 @@ class LIBCELLML_EXPORT Analyser: public Logger
* @brief Add an @ref AnalyserExternalVariable to this @ref Analyser.
*
* Add the given @ref AnalyserExternalVariable to this @ref Analyser, but only if it has not already been added.
* Please note that it is your responsibility to ensure that all the variables on which an external variable depends
* are tracked.
*
* @param externalVariable The @ref AnalyserExternalVariable to add.
*
Expand Down
63 changes: 41 additions & 22 deletions src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1853,16 +1853,23 @@ std::string Generator::GeneratorImpl::generateCode(const AnalyserModelPtr &model
code = generateCode(model, ast->leftChild());

break;
case AnalyserEquationAst::Type::CI:
case AnalyserEquationAst::Type::CI: {
code = generateVariableNameCode(model, ast->variable(), ast->parent()->type() != AnalyserEquationAst::Type::DIFF);

auto astParent = ast->parent();

if ((model != nullptr)
&& (ast->parent()->type() == AnalyserEquationAst::Type::EQUALITY)
&& (astParent->type() == AnalyserEquationAst::Type::EQUALITY)
&& (astParent->leftChild() == ast)
&& isUntrackedVariable(model->variable(ast->variable()))) {
// Note: we want this AST to be its parent's left child since a declaration is always of the form x = RHS,
// not LHS = x.

code = replace(mProfile->variableDeclarationString(), "[CODE]", code);
}
}

break;
break;
case AnalyserEquationAst::Type::CN:
code = generateDoubleCode(ast->value());

Expand Down Expand Up @@ -2012,11 +2019,13 @@ std::string Generator::GeneratorImpl::generateEquationCode(const AnalyserModelPt

if (!isSomeConstant(equation, includeComputedConstants)) {
for (const auto &dependency : equation->dependencies()) {
if ((dependency->type() != AnalyserEquation::Type::ODE)
&& !isSomeConstant(dependency, includeComputedConstants)
&& (equationsForDependencies.empty()
|| isToBeComputedAgain(dependency)
|| (std::find(equationsForDependencies.begin(), equationsForDependencies.end(), dependency) != equationsForDependencies.end()))) {
if (((dependency->type() == AnalyserEquation::Type::COMPUTED_CONSTANT)
&& isUntrackedVariable(dependency->computedConstants().front()))
|| ((dependency->type() != AnalyserEquation::Type::ODE)
&& !isSomeConstant(dependency, includeComputedConstants)
&& (equationsForDependencies.empty()
|| isToBeComputedAgain(dependency)
|| (std::find(equationsForDependencies.begin(), equationsForDependencies.end(), dependency) != equationsForDependencies.end())))) {
res += generateEquationCode(model, dependency, remainingEquations, equationsForDependencies,
generatedConstantDependencies, includeComputedConstants);
}
Expand All @@ -2028,12 +2037,16 @@ std::string Generator::GeneratorImpl::generateEquationCode(const AnalyserModelPt
switch (equation->type()) {
case AnalyserEquation::Type::EXTERNAL:
for (const auto &variable : variables(equation)) {
auto code = generateVariableNameCode(model, variable->variable())
+ mProfile->equalityString()
+ replace(mProfile->externalVariableMethodCallString(modelHasOdes(model)),
"[INDEX]", convertToString(variable->index()))
+ mProfile->commandSeparatorString() + "\n";

code = replace(mProfile->variableDeclarationString(), "[CODE]", code);

res += mProfile->indentString()
+ generateVariableNameCode(model, variable->variable())
+ mProfile->equalityString()
+ replace(mProfile->externalVariableMethodCallString(modelHasOdes(model)),
"[INDEX]", convertToString(variable->index()))
+ mProfile->commandSeparatorString() + "\n";
+ code;
}

break;
Expand All @@ -2057,10 +2070,10 @@ std::string Generator::GeneratorImpl::generateEquationCode(const AnalyserModelPt

std::string Generator::GeneratorImpl::generateEquationCode(const AnalyserModelPtr &model,
const AnalyserEquationPtr &equation,
std::vector<AnalyserEquationPtr> &remainingEquations)
std::vector<AnalyserEquationPtr> &remainingEquations,
std::vector<AnalyserVariablePtr> &generatedConstantDependencies)
{
std::vector<AnalyserEquationPtr> dummyEquationsForComputeVariables;
std::vector<AnalyserVariablePtr> generatedConstantDependencies;

return generateEquationCode(model, equation, remainingEquations, dummyEquationsForComputeVariables,
generatedConstantDependencies, true);
Expand Down Expand Up @@ -2173,11 +2186,11 @@ void Generator::GeneratorImpl::addImplementationInitialiseVariablesMethodCode(co
// Initialise our computed constants that are initialised using an equation (e.g., x = 3 rather than x with an
// initial value of 3).

auto equations = model->equations();
std::vector<AnalyserVariablePtr> generatedConstantDependencies;

for (const auto &equation : equations) {
for (const auto &equation : model->equations()) {
if (equation->type() == AnalyserEquation::Type::CONSTANT) {
methodBody += generateEquationCode(model, equation, remainingEquations);
methodBody += generateEquationCode(model, equation, remainingEquations, generatedConstantDependencies);
}
}

Expand Down Expand Up @@ -2207,11 +2220,12 @@ void Generator::GeneratorImpl::addImplementationComputeComputedConstantsMethodCo
{
if (!mProfile->implementationComputeComputedConstantsMethodString().empty()) {
std::string methodBody;
std::vector<AnalyserVariablePtr> generatedConstantDependencies;

for (const auto &equation : model->equations()) {
if ((equation->type() == AnalyserEquation::Type::COMPUTED_CONSTANT)
&& isTrackedVariable(equation->computedConstants().front())) {
methodBody += generateEquationCode(model, equation, remainingEquations);
methodBody += generateEquationCode(model, equation, remainingEquations, generatedConstantDependencies);
}
}

Expand All @@ -2229,6 +2243,7 @@ void Generator::GeneratorImpl::addImplementationComputeRatesMethodCode(const Ana
if (modelHasOdes(model)
&& !implementationComputeRatesMethodString.empty()) {
std::string methodBody;
std::vector<AnalyserVariablePtr> generatedConstantDependencies;

for (const auto &equation : model->equations()) {
// A rate is computed either through an ODE equation or through an
Expand All @@ -2241,7 +2256,7 @@ void Generator::GeneratorImpl::addImplementationComputeRatesMethodCode(const Ana
|| ((equation->type() == AnalyserEquation::Type::NLA)
&& (variables.size() == 1)
&& (variables[0]->type() == AnalyserVariable::Type::STATE))) {
methodBody += generateEquationCode(model, equation, remainingEquations);
methodBody += generateEquationCode(model, equation, remainingEquations, generatedConstantDependencies);
}
}

Expand All @@ -2264,8 +2279,12 @@ void Generator::GeneratorImpl::addImplementationComputeVariablesMethodCode(const
std::vector<AnalyserVariablePtr> generatedConstantDependencies;

for (const auto &equation : equations) {
if ((std::find(remainingEquations.begin(), remainingEquations.end(), equation) != remainingEquations.end())
|| isToBeComputedAgain(equation)) {
if (((std::find(remainingEquations.begin(), remainingEquations.end(), equation) != remainingEquations.end())
|| isToBeComputedAgain(equation))
&& (((equation->type() == AnalyserEquation::Type::ALGEBRAIC)
&& isTrackedVariable(equation->algebraic().front()))
|| ((equation->type() == AnalyserEquation::Type::EXTERNAL)
&& isTrackedVariable(equation->externals().front())))) {
methodBody += generateEquationCode(model, equation, newRemainingEquations, remainingEquations,
generatedConstantDependencies, false);
}
Expand Down
3 changes: 2 additions & 1 deletion src/generator_p.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ struct Generator::GeneratorImpl
std::vector<AnalyserVariablePtr> &generatedConstantDependencies,
bool includeComputedConstants);
std::string generateEquationCode(const AnalyserModelPtr &model, const AnalyserEquationPtr &equation,
std::vector<AnalyserEquationPtr> &remainingEquations);
std::vector<AnalyserEquationPtr> &remainingEquations,
std::vector<AnalyserVariablePtr> &generatedConstantDependencies);

void addInterfaceComputeModelMethodsCode(const AnalyserModelPtr &model);
std::string generateConstantInitialisationCode(const AnalyserModelPtr &model,
Expand Down
17 changes: 17 additions & 0 deletions tests/coverage/coverage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,23 @@ TEST(Coverage, generator)
libcellml::Generator::equationCode(analyser->model()->equation(0)->ast());
}

TEST(Coverage, generatorWithNoTracking)
{
auto parser = libcellml::Parser::create();
auto model = parser->parseModel(fileContents("coverage/generator/model.cellml"));
auto analyser = libcellml::Analyser::create();

analyser->analyseModel(model);

auto analyserModel = analyser->model();
auto generator = libcellml::Generator::create();

generator->untrackAllVariables(analyserModel);

EXPECT_EQ_FILE_CONTENTS("coverage/generator/model.no.tracking.h", generator->interfaceCode(analyserModel));
EXPECT_EQ_FILE_CONTENTS("coverage/generator/model.no.tracking.c", generator->implementationCode(analyserModel));
}

TEST(CoverageValidator, degreeElementWithOneSibling)
{
const std::string math =
Expand Down
140 changes: 129 additions & 11 deletions tests/generator/generatortrackedvariables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ TEST(GeneratorTrackedVariables, trackAndUntrackAllVariables)
EXPECT_EQ(size_t(0), generator->untrackedVariableCount(analyserModel));
}

TEST(GeneratorTrackedVariables, hodgkinHuxleySquidAxonModel1952NoTracking)
TEST(GeneratorTrackedVariables, hodgkinHuxleySquidAxonModel1952UntrackedVariables)
{
auto parser = libcellml::Parser::create();
auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml"));
Expand All @@ -452,24 +452,142 @@ TEST(GeneratorTrackedVariables, hodgkinHuxleySquidAxonModel1952NoTracking)

generator->untrackAllVariables(analyserModel);

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.no.tracking.h", generator->interfaceCode(analyserModel));
EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.no.tracking.c", generator->implementationCode(analyserModel));
auto profile = generator->profile();

auto profile = libcellml::GeneratorProfile::create(libcellml::GeneratorProfile::Profile::PYTHON);
profile->setInterfaceFileNameString("model.untracked.variables.h");

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.variables.h", generator->interfaceCode(analyserModel));
EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.variables.c", generator->implementationCode(analyserModel));

profile = libcellml::GeneratorProfile::create(libcellml::GeneratorProfile::Profile::PYTHON);

generator->setProfile(profile);

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.variables.py", generator->implementationCode(analyserModel));

// With some external variables.

auto potassium_channel_n_gate_alpha_n = model->component("potassium_channel_n_gate")->variable("alpha_n");
auto external_sodium_channel_i_Na = libcellml::AnalyserExternalVariable::create(model->component("sodium_channel")->variable("i_Na"));

external_sodium_channel_i_Na->addDependency(potassium_channel_n_gate_alpha_n);
external_sodium_channel_i_Na->addDependency(model->component("sodium_channel_h_gate")->variable("h"));

analyser->addExternalVariable(libcellml::AnalyserExternalVariable::create(model->component("membrane")->variable("V")));
analyser->addExternalVariable(external_sodium_channel_i_Na);
analyser->addExternalVariable(libcellml::AnalyserExternalVariable::create(potassium_channel_n_gate_alpha_n));

analyser->analyseModel(model);

analyserModel = analyser->model();

generator->untrackAllVariables(analyserModel);

profile = libcellml::GeneratorProfile::create(libcellml::GeneratorProfile::Profile::C);

generator->setProfile(profile);

profile->setInterfaceFileNameString("model.untracked.variables.with.externals.h");

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.variables.with.externals.h", generator->interfaceCode(analyserModel));
EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.variables.with.externals.c", generator->implementationCode(analyserModel));

profile = libcellml::GeneratorProfile::create(libcellml::GeneratorProfile::Profile::PYTHON);

generator->setProfile(profile);

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.variables.with.externals.py", generator->implementationCode(analyserModel));
}

TEST(GeneratorTrackedVariables, hodgkinHuxleySquidAxonModel1952UntrackedConstants)
{
auto parser = libcellml::Parser::create();
auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml"));
auto analyser = libcellml::Analyser::create();

analyser->analyseModel(model);

auto analyserModel = analyser->model();
auto generator = libcellml::Generator::create();

generator->untrackAllConstants(analyserModel);

auto profile = generator->profile();

profile->setInterfaceFileNameString("model.untracked.constants.h");

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.constants.h", generator->interfaceCode(analyserModel));
EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.constants.c", generator->implementationCode(analyserModel));

profile = libcellml::GeneratorProfile::create(libcellml::GeneratorProfile::Profile::PYTHON);

generator->setProfile(profile);

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.constants.py", generator->implementationCode(analyserModel));
}

TEST(GeneratorTrackedVariables, hodgkinHuxleySquidAxonModel1952UntrackedComputedConstants)
{
auto parser = libcellml::Parser::create();
auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml"));
auto analyser = libcellml::Analyser::create();

analyser->analyseModel(model);

auto analyserModel = analyser->model();
auto generator = libcellml::Generator::create();

generator->untrackAllComputedConstants(analyserModel);

auto profile = generator->profile();

profile->setInterfaceFileNameString("model.untracked.computed.constants.h");

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.computed.constants.h", generator->interfaceCode(analyserModel));
EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.computed.constants.c", generator->implementationCode(analyserModel));

profile = libcellml::GeneratorProfile::create(libcellml::GeneratorProfile::Profile::PYTHON);

generator->setProfile(profile);

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.computed.constants.py", generator->implementationCode(analyserModel));
}

TEST(GeneratorTrackedVariables, hodgkinHuxleySquidAxonModel1952UntrackedAlgebraicVariables)
{
auto parser = libcellml::Parser::create();
auto model = parser->parseModel(fileContents("generator/hodgkin_huxley_squid_axon_model_1952/model.cellml"));
auto analyser = libcellml::Analyser::create();

analyser->analyseModel(model);

auto analyserModel = analyser->model();
auto generator = libcellml::Generator::create();

generator->untrackAllAlgebraic(analyserModel);

auto profile = generator->profile();

profile->setInterfaceFileNameString("model.untracked.algebraic.variables.h");

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.algebraic.variables.h", generator->interfaceCode(analyserModel));
EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.algebraic.variables.c", generator->implementationCode(analyserModel));

profile = libcellml::GeneratorProfile::create(libcellml::GeneratorProfile::Profile::PYTHON);

generator->setProfile(profile);

EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.no.tracking.py", generator->implementationCode(analyserModel));
EXPECT_EQ_FILE_CONTENTS("generator/hodgkin_huxley_squid_axon_model_1952/model.untracked.algebraic.variables.py", generator->implementationCode(analyserModel));
}

/**
* Need the following tests:
* - ODE HH52 + No tracking
* - ODE HH52 + No tracking + Externals (with some dependencies on untracked variables)
* - ODE HH52 + Some tracking (with some dependencies on untracked variables)
* - ODE HH52 + Some tracking (with some dependencies on untracked variables) + Externals (with some dependencies on untracked variables)
* - ODE HH52 + No tracking -- DONE
* - ODE HH52 + No tracking + Externals -- DONE
* - ODE HH52 + Some tracking (with some dependencies on untracked variables) -- DONE
* - ODE HH52 + Some tracking (with some dependencies on untracked variables) + Externals
* - DAE HH52 + No tracking
* - DAE HH52 + No tracking + Externals (with some dependencies on untracked variables)
* - DAE HH52 + No tracking + Externals
* - DAE HH52 + Some tracking (with some dependencies on untracked variables)
* - DAE HH52 + Some tracking (with some dependencies on untracked variables) + Externals (with some dependencies on untracked variables)
* - DAE HH52 + Some tracking (with some dependencies on untracked variables) + Externals
*/
Loading

0 comments on commit b2748c5

Please sign in to comment.