diff --git a/Bindings/OpenSimHeaders_common.h b/Bindings/OpenSimHeaders_common.h index 44c5f343c5..333b4d7a05 100644 --- a/Bindings/OpenSimHeaders_common.h +++ b/Bindings/OpenSimHeaders_common.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include diff --git a/Bindings/common.i b/Bindings/common.i index 0b84575d90..ec66faca3f 100644 --- a/Bindings/common.i +++ b/Bindings/common.i @@ -60,6 +60,7 @@ %include %include %include +%include %include %include diff --git a/CHANGELOG.md b/CHANGELOG.md index 2923ade477..a1a282a8d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ v4.6 - Added `Output`s to `ExpressionBasedCoordinateForce`, `ExpressionBasedPointToPointForce`, and `ExpressionBasedBushingForce` for accessing force values. (#3872) - `PointForceDirection` no longer has a virtual destructor, is `final`, and its `scale` functionality has been marked as `[[deprecated]]` (#3890) +- Added `ExpressionBasedFunction` for creating `Function`s based on user-defined mathematical expressions. (#3892) v4.5.1 ====== diff --git a/OpenSim/Common/CMakeLists.txt b/OpenSim/Common/CMakeLists.txt index ef0508c4b6..afabfd39de 100644 --- a/OpenSim/Common/CMakeLists.txt +++ b/OpenSim/Common/CMakeLists.txt @@ -16,7 +16,7 @@ OpenSimAddLibrary( KIT Common AUTHORS "Clay_Anderson-Ayman_Habib-Peter_Loan" # Clients of osimCommon need not link to ezc3d. - LINKLIBS PUBLIC ${Simbody_LIBRARIES} spdlog::spdlog + LINKLIBS PUBLIC ${Simbody_LIBRARIES} spdlog::spdlog osimLepton PRIVATE ${ezc3d_LIBRARY} INCLUDES ${INCLUDES} SOURCES ${SOURCES} diff --git a/OpenSim/Common/ExpressionBasedFunction.cpp b/OpenSim/Common/ExpressionBasedFunction.cpp new file mode 100644 index 0000000000..63710eb885 --- /dev/null +++ b/OpenSim/Common/ExpressionBasedFunction.cpp @@ -0,0 +1,127 @@ +/* -------------------------------------------------------------------------- * + * OpenSim: ExpressionBasedFunction.cpp * + * -------------------------------------------------------------------------- * + * The OpenSim API is a toolkit for musculoskeletal modeling and simulation. * + * See http://opensim.stanford.edu and the NOTICE file for more information. * + * OpenSim is developed at Stanford University and supported by the US * + * National Institutes of Health (U54 GM072970, R24 HD065690) and by DARPA * + * through the Warrior Web program. * + * * + * Copyright (c) 2005-2024 Stanford University and the Authors * + * Author(s): Nicholas Bianco * + * * + * Licensed under the Apache License, Version 2.0 (the "License"); you may * + * not use this file except in compliance with the License. You may obtain a * + * copy of the License at http://www.apache.org/licenses/LICENSE-2.0. * + * * + * Unless required by applicable law or agreed to in writing, software * + * distributed under the License is distributed on an "AS IS" BASIS, * + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * + * See the License for the specific language governing permissions and * + * limitations under the License. * + * -------------------------------------------------------------------------- */ + +#include "ExpressionBasedFunction.h" + +#include +#include +#include +#include + +using namespace OpenSim; + +class SimTKExpressionBasedFunction : public SimTK::Function { +public: + SimTKExpressionBasedFunction(const std::string& expression, + const std::vector& variables) : + m_expression(expression), m_variables(variables) { + + // Check that the variable names are unique. + std::set uniqueVariables; + for (const auto& variable : m_variables) { + if (!uniqueVariables.insert(variable).second) { + OPENSIM_THROW(Exception, + fmt::format("Variable '{}' is defined more than once.", + variable)); + } + } + + // Create the expression programs for the value and its derivatives. + Lepton::ParsedExpression parsedExpression = + Lepton::Parser::parse(m_expression).optimize(); + m_valueProgram = parsedExpression.createProgram(); + + for (int i = 0; i < static_cast(m_variables.size()); ++i) { + Lepton::ParsedExpression diffExpression = + parsedExpression.differentiate(m_variables[i]).optimize(); + m_derivativePrograms.push_back(diffExpression.createProgram()); + } + + try { + std::map vars; + for (int i = 0; i < static_cast(m_variables.size()); ++i) { + vars[m_variables[i]] = 0; + } + m_valueProgram.evaluate(vars); + + for (int i = 0; i < static_cast(m_variables.size()); ++i) { + m_derivativePrograms[i].evaluate(vars); + } + } catch (Lepton::Exception& ex) { + std::string msg = ex.what(); + std::string undefinedVar = msg.substr(32, msg.size() - 32); + OPENSIM_THROW(Exception, + fmt::format("Variable '{}' is not defined. Use " + "setVariables() to explicitly define this variable. Or, " + "remove it from the expression.", undefinedVar)); + } + } + + SimTK::Real calcValue(const SimTK::Vector& x) const override { + OPENSIM_ASSERT(x.size() == static_cast(m_variables.size())); + std::map vars; + for (int i = 0; i < static_cast(m_variables.size()); ++i) { + vars[m_variables[i]] = x[i]; + } + return m_valueProgram.evaluate(vars); + } + + SimTK::Real calcDerivative(const SimTK::Array_& derivComponents, + const SimTK::Vector& x) const override { + OPENSIM_ASSERT(x.size() == static_cast(m_variables.size())); + OPENSIM_ASSERT(derivComponents.size() == 1); + if (derivComponents[0] < static_cast(m_variables.size())) { + std::map vars; + for (int i = 0; i < static_cast(m_variables.size()); ++i) { + vars[m_variables[i]] = x[i]; + } + return m_derivativePrograms[derivComponents[0]].evaluate(vars); + } + return 0.0; + } + + int getArgumentSize() const override { + return static_cast(m_variables.size()); + } + int getMaxDerivativeOrder() const override { return 1; } + SimTKExpressionBasedFunction* clone() const override { + return new SimTKExpressionBasedFunction(*this); + } + +private: + std::string m_expression; + std::vector m_variables; + Lepton::ExpressionProgram m_valueProgram; + std::vector m_derivativePrograms; +}; + +SimTK::Function* ExpressionBasedFunction::createSimTKFunction() const { + OPENSIM_THROW_IF_FRMOBJ(get_expression().empty(), Exception, + "The expression has not been set. Use setExpression().") + + std::vector variables; + for (int i = 0; i < getProperty_variables().size(); ++i) { + variables.push_back(get_variables(i)); + } + return new SimTKExpressionBasedFunction(get_expression(), variables); +} \ No newline at end of file diff --git a/OpenSim/Common/ExpressionBasedFunction.h b/OpenSim/Common/ExpressionBasedFunction.h new file mode 100644 index 0000000000..b05d2dc70d --- /dev/null +++ b/OpenSim/Common/ExpressionBasedFunction.h @@ -0,0 +1,134 @@ +#ifndef OPENSIM_EXPRESSION_BASED_FUNCTION_H_ +#define OPENSIM_EXPRESSION_BASED_FUNCTION_H_ +/* -------------------------------------------------------------------------- * + * OpenSim: ExpressionBasedFunction.h * + * -------------------------------------------------------------------------- * + * The OpenSim API is a toolkit for musculoskeletal modeling and simulation. * + * See http://opensim.stanford.edu and the NOTICE file for more information. * + * OpenSim is developed at Stanford University and supported by the US * + * National Institutes of Health (U54 GM072970, R24 HD065690) and by DARPA * + * through the Warrior Web program. * + * * + * Copyright (c) 2005-2024 Stanford University and the Authors * + * Author(s): Nicholas Bianco * + * * + * Licensed under the Apache License, Version 2.0 (the "License"); you may * + * not use this file except in compliance with the License. You may obtain a * + * copy of the License at http://www.apache.org/licenses/LICENSE-2.0. * + * * + * Unless required by applicable law or agreed to in writing, software * + * distributed under the License is distributed on an "AS IS" BASIS, * + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * + * See the License for the specific language governing permissions and * + * limitations under the License. * + * -------------------------------------------------------------------------- */ + +#include "osimCommonDLL.h" +#include "Function.h" +namespace OpenSim { + +/** + * A function based on a user-defined mathematical expression. + * + * This class allows users to define a function based on a mathematical + * expression (e.g., "x*sqrt(y-8)"). The expression can be a function of any + * number of independent variables. The expression is parsed and evaluated using + * the Lepton library. + * + * Set the expression using setExpression(). Any variables used in the + * expression must be explicitly defined using setVariables(). This + * implementation allows computation of first-order derivatives only. + * + * # Creating Expressions + * + * Expressions can contain variables, constants, operations, parentheses, commas, + * spaces, and scientific "e" notation. The full list of supported operations is: + * sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, + * tanh, erf, erfc, step, delta, square, cube, recip, min, max, abs, +, -, *, /, + * and ^. + */ +class OSIMCOMMON_API ExpressionBasedFunction : public Function { + OpenSim_DECLARE_CONCRETE_OBJECT(ExpressionBasedFunction, Function); + +public: +//============================================================================== +// PROPERTIES +//============================================================================== + OpenSim_DECLARE_PROPERTY(expression, std::string, + "The mathematical expression defining this Function."); + OpenSim_DECLARE_LIST_PROPERTY(variables, std::string, + "The independent variables used by this Function's expression. " + "In XML, variable names should be space-separated."); + +//============================================================================== +// METHODS +//============================================================================== + + /** Default constructor. */ + ExpressionBasedFunction() { constructProperties(); } + + /** Convenience constructor. + * + * @param expression The expression that defines this Function. + * @param variables The independent variable names of this expression. + */ + ExpressionBasedFunction(std::string expression, + const std::vector& variables) { + constructProperties(); + set_expression(std::move(expression)); + setVariables(variables); + } + + /** + * The mathematical expression that defines this Function. The expression + * should be a function of the variables defined via setVariables(). + * + * @note The expression cannot contain any whitespace characters. + */ + void setExpression(std::string expression) { + set_expression(std::move(expression)); + } + /// @copydoc setExpression() + const std::string& getExpression() const { + return get_expression(); + } + + /** + * The independent variable names of this expression. The variables names + * should be unique and should be comprised of alphabetic characters or any + * characters not reserved by Lepton (i.e., +, -, *, /, and ^). Variable + * names can contain numbers as long they do not come first in the name + * (e.g., "var0"). The input vector passed to calcValue() and + * calcDerivative() should be in the same order as the variables defined + * here. + */ + void setVariables(const std::vector& variables) { + for (const auto& var : variables) { + append_variables(var); + } + } + /// @copydoc setVariables() + std::vector getVariables() const { + std::vector variables; + for (int i = 0; i < getProperty_variables().size(); ++i) { + variables.push_back(get_variables(i)); + } + return variables; + } + + /** + * Return a pointer to a SimTK::Function object that implements this + * function. + */ + SimTK::Function* createSimTKFunction() const override; + +private: + void constructProperties() { + constructProperty_expression(""); + constructProperty_variables(); + } +}; + +} // namespace OpenSim + +#endif // OPENSIM_EXPRESSION_BASED_FUNCTION_H_ \ No newline at end of file diff --git a/OpenSim/Common/RegisterTypes_osimCommon.cpp b/OpenSim/Common/RegisterTypes_osimCommon.cpp index 2be4dec912..114a4fde20 100644 --- a/OpenSim/Common/RegisterTypes_osimCommon.cpp +++ b/OpenSim/Common/RegisterTypes_osimCommon.cpp @@ -41,6 +41,7 @@ #include "MultiplierFunction.h" #include "PolynomialFunction.h" #include "MultivariatePolynomialFunction.h" +#include "ExpressionBasedFunction.h" #include "SignalGenerator.h" @@ -88,6 +89,7 @@ OSIMCOMMON_API void RegisterTypes_osimCommon() Object::registerType( MultiplierFunction() ); Object::registerType( PolynomialFunction() ); Object::registerType( MultivariatePolynomialFunction() ); + Object::registerType( ExpressionBasedFunction() ); Object::registerType( SignalGenerator() ); diff --git a/OpenSim/Common/Test/testFunctions.cpp b/OpenSim/Common/Test/testFunctions.cpp index 69fb85d81f..d854c914c9 100644 --- a/OpenSim/Common/Test/testFunctions.cpp +++ b/OpenSim/Common/Test/testFunctions.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include @@ -260,3 +261,92 @@ TEST_CASE("solveBisection()") { REQUIRE_THROWS_AS(solveBisection(parabola, -5, 5), OpenSim::Exception); } } + +TEST_CASE("ExpressionBasedFunction") { + const SimTK::Real x = SimTK::Test::randReal(); + const SimTK::Real y = SimTK::Test::randReal(); + const SimTK::Real z = SimTK::Test::randReal(); + + SECTION("Square-root function") { + ExpressionBasedFunction f("sqrt(x)", {"x"}); + REQUIRE_THAT(f.calcValue(createVector({x})), + Catch::Matchers::WithinAbs(std::sqrt(x), 1e-10)); + REQUIRE_THAT(f.calcDerivative({0}, createVector({x})), + Catch::Matchers::WithinAbs(0.5 / std::sqrt(x), 1e-10)); + } + + SECTION("Exponential function") { + ExpressionBasedFunction f("exp(x)", {"x"}); + REQUIRE_THAT(f.calcValue(createVector({x})), + Catch::Matchers::WithinAbs(std::exp(x), 1e-10)); + REQUIRE_THAT(f.calcDerivative({0}, createVector({x})), + Catch::Matchers::WithinAbs(std::exp(x), 1e-10)); + } + + SECTION("Multivariate function") { + ExpressionBasedFunction f("2*x^3 + 3*y*z^2", {"x", "y", "z"}); + REQUIRE_THAT(f.calcValue(createVector({x, y, z})), + Catch::Matchers::WithinAbs(2*x*x*x + 3*y*z*z, 1e-10)); + REQUIRE_THAT(f.calcDerivative({0}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(6*x*x, 1e-10)); + REQUIRE_THAT(f.calcDerivative({1}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(3*z*z, 1e-10)); + REQUIRE_THAT(f.calcDerivative({2}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(6*y*z, 1e-10)); + } + + + SECTION("Sinusoidal function") { + ExpressionBasedFunction f("x*sin(y*z^2)", {"x", "y", "z"}); + REQUIRE_THAT(f.calcValue(createVector({x, y, z})), + Catch::Matchers::WithinAbs(x * std::sin(y*z*z), 1e-10)); + REQUIRE_THAT(f.calcDerivative({0}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(std::sin(y*z*z), 1e-10)); + REQUIRE_THAT(f.calcDerivative({1}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(x*z*z*std::cos(y*z*z), 1e-10)); + } + + SECTION("Undefined variable in expression") { + ExpressionBasedFunction f("x*y", {"x"}); + REQUIRE_THROWS_WITH(f.calcValue(createVector({x, y})), + Catch::Matchers::ContainsSubstring( + "Variable 'y' is not defined.")); + } + + SECTION("Extra variables should have zero derivative") { + ExpressionBasedFunction f("x*y", {"x", "y", "z"}); + REQUIRE_THAT(f.calcValue(createVector({x, y, z})), + Catch::Matchers::WithinAbs(x*y, 1e-10)); + REQUIRE_THAT(f.calcDerivative({0}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(y, 1e-10)); + REQUIRE_THAT(f.calcDerivative({1}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(x, 1e-10)); + REQUIRE_THAT(f.calcDerivative({2}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(0.0, 1e-10)); + } + + SECTION("Derivative of nonexistent variable") { + ExpressionBasedFunction f("x*y", {"x", "y"}); + REQUIRE_THAT(f.calcDerivative({2}, createVector({x, y})), + Catch::Matchers::WithinAbs(0.0, 1e-10)); + } + + SECTION("Variable defined multiple times") { + ExpressionBasedFunction f("x", {"x", "x"}); + REQUIRE_THROWS_WITH(f.calcValue(createVector({x})), + Catch::Matchers::ContainsSubstring( + "Variable 'x' is defined more than once.")); + } + + SECTION("Non-alphabetic variable names") { + ExpressionBasedFunction f("@^2 + %*cos(&)", {"@", "%", "&"}); + REQUIRE_THAT(f.calcValue(createVector({x, y, z})), + Catch::Matchers::WithinAbs(x*x + y*std::cos(z), 1e-10)); + REQUIRE_THAT(f.calcDerivative({0}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(2*x, 1e-10)); + REQUIRE_THAT(f.calcDerivative({1}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(std::cos(z), 1e-10)); + REQUIRE_THAT(f.calcDerivative({2}, createVector({x, y, z})), + Catch::Matchers::WithinAbs(-y*std::sin(z), 1e-10)); + } +} \ No newline at end of file diff --git a/OpenSim/Common/osimCommon.h b/OpenSim/Common/osimCommon.h index 23cd153237..19adf1e7bb 100644 --- a/OpenSim/Common/osimCommon.h +++ b/OpenSim/Common/osimCommon.h @@ -29,6 +29,7 @@ #include "CommonUtilities.h" #include "Constant.h" #include "DataTable.h" +#include "ExpressionBasedFunction.h" #include "FunctionSet.h" #include "GCVSpline.h" #include "GCVSplineSet.h"