Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Homogenize Python Bindings for DirichletValues + Support for fixing Subspace Basis #305

Merged
merged 19 commits into from
Jul 18, 2024
Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ SPDX-License-Identifier: LGPL-3.0-or-later
([#304](https://github.com/ikarus-project/ikarus/pull/304))
- This can be used, for instance, to apply concentrated forces or to add spring stiffness in a particular direction.
- Furthermore, a helper function to get the global index of a Lagrange node at the given global position is added.
- Rework Python Interface for `DirichletValues` plus adding support to easily fix boundary DOFs of subspacebasis in C++ and Python ([#305](https://github.com/ikarus-project/ikarus/pull/305))
henrij22 marked this conversation as resolved.
Show resolved Hide resolved
- Rework the Python Interface for `DirichletValues` plus add support to easily fix boundary DOFs of `Subspacebasis` in C++ and Python ([#305](https://github.com/ikarus-project/ikarus/pull/305))

## Release v0.4 (Ganymede)

Expand Down
16 changes: 9 additions & 7 deletions docs/website/01_framework/dirichletBCs.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,25 @@ The interface of `#!cpp Ikarus::DirichletValues` is represented by the following
Ikarus::DirichletValues dirichletValues2(basis); // (1)!
void fixBoundaryDOFs(f); // (2)!
void fixDOFs(f); // (3)!
void fixIthDOF(i); // (4)!
void setSingleDOF(i, flag); // (4)!
const auto& basis() const; // (5)!
bool isConstrained(std::size_t i) const; // (6)!
bool isConstrained(i) const; // (6)!
auto fixedDOFsize() const; // (7)!
auto size() const ; // (8)!
auto reset(); // (9)!
```

1. Create class by inserting a global basis, [@sander2020dune] Chapter 10.
2. Accepts a functor to fix boundary degrees of freedom. `f` is a functor that will be called with the boolean vector of fixed boundary.
2. Accepts a functor to fix boundary degrees of freedom. `f` is a functor that will be called with the Boolean vector of fixed boundary.
degrees of freedom and the usual arguments of `Dune::Functions::forEachBoundaryDOF`, as defined on page 388 of the Dune
[@sander2020dune] book.
3. A more general version of `fixBoundaryDOFs`. Here, a functor is to be provided that accepts a basis and the corresponding boolean
4. A function that helps to fix the $i$-th degree of freedom
3. A more general version of `fixBoundaryDOFs`. Here, a functor is to be provided that accepts a basis and the corresponding Boolean
4. A function that helps to fix or unfix the $i$-th degree of freedom
vector considering the Dirichlet degrees of freedom.
5. Returns the underlying basis.
6. Indicates whether the degree of freedom `i` is fixed.
6. Indicates whether the degree of freedom $i$ is fixed.
7. Returns the number of fixed degrees of freedom.
8. Returns the number of all dirichlet degrees of freedom.
8. Returns the number of all Dirichlet degrees of freedom.
9. Resets the whole container

\bibliography
185 changes: 131 additions & 54 deletions ikarus/python/dirichletvalues/dirichletvalues.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,133 @@

#pragma once

#include <cstdlib>
#include <string>

#include "dune/common/classname.hh"
#include <dune/functions/functionspacebases/lagrangebasis.hh>
#include <dune/functions/functionspacebases/powerbasis.hh>
#include <dune/grid/yaspgrid.hh>
#include <dune/python/common/typeregistry.hh>
#include <dune/python/functions/globalbasis.hh>
#include <dune/python/functions/subspacebasis.hh>
#include <dune/python/pybind11/eigen.h>
#include <dune/python/pybind11/functional.h>
#include <dune/python/pybind11/pybind11.h>
#include <dune/python/pybind11/stl.h>
#include <dune/python/pybind11/stl_bind.h>

#include <ikarus/finiteelements/ferequirements.hh>

// PYBIND11_MAKE_OPAQUE(std::vector<bool>);
namespace Ikarus::Python {

namespace Impl {
using FixBoundaryDOFsWithGlobalIndexFunction = std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int)>;

template <typename LV>
using FixBoundaryDOFsWithLocalViewFunction = std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LV&)>;

template <typename LV, typename IS>
using FixBoundaryDOFsWithIntersectionFunction =
std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LV&, const IS&)>;

template <typename Basis>
auto registerSubSpaceLocalView() {
pybind11::module scopedf = pybind11::module::import("dune.functions");
using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;

auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};

Dune::Python::insertClass<Basis>(scopedf, "SubspaceBasis_" + Dune::className<typename Basis::PrefixPath>(),
rath3t marked this conversation as resolved.
Show resolved Hide resolved
Dune::Python::GenerateTypeName(Dune::className<Basis>()), includes);

auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalView_" + Dune::className<typename Basis::PrefixPath>(),
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()), includes);
if (isNew) {
lv.def("bind", &LocalViewWrapper::bind);
lv.def("unbind", &LocalViewWrapper::unbind);
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });

Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
}
}
} // namespace Impl

template <class DirichletValues>
void forwardCorrectFunction(DirichletValues& dirichletValues, const pybind11::function& functor, auto&& cppFunction) {
using Basis = typename DirichletValues::Basis;
using Intersection = typename Basis::GridView::Intersection;
using BackendType = typename DirichletValues::BackendType;
using MultiIndex = typename Basis::MultiIndex;

// Disambiguate by number of arguments
pybind11::module inspect_module = pybind11::module::import("inspect");
pybind11::object result = inspect_module.attr("signature")(functor).attr("parameters");
size_t numParams = pybind11::len(result);

if (numParams == 2) {
auto function = functor.template cast<const Impl::FixBoundaryDOFsWithGlobalIndexFunction>();
auto lambda = [&](BackendType& vec, const MultiIndex& indexGlobal) { function(vec.vector(), indexGlobal); };
cppFunction(lambda);

} else if (numParams == 3) {
auto lambda = [&](BackendType& vec, int localIndex, auto&& lv) {
using SubSpaceBasis = typename std::remove_cvref_t<decltype(lv)>::GlobalBasis;
Impl::registerSubSpaceLocalView<SubSpaceBasis>();

using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
auto lvWrapper = SubSpaceLocalViewWrapper(lv);

auto function =
functor.template cast<const Impl::FixBoundaryDOFsWithLocalViewFunction<SubSpaceLocalViewWrapper>>();
function(vec.vector(), localIndex, lvWrapper);
};
cppFunction(lambda);

} else if (numParams == 4) {
auto lambda = [&](BackendType& vec, int localIndex, auto&& lv, const Intersection& intersection) {
using SubSpaceBasis = typename std::remove_cvref_t<decltype(lv)>::GlobalBasis;
Impl::registerSubSpaceLocalView<SubSpaceBasis>();

using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
auto lvWrapper = SubSpaceLocalViewWrapper(lv);

auto function = functor.template cast<
const Impl::FixBoundaryDOFsWithIntersectionFunction<SubSpaceLocalViewWrapper, Intersection>>();
function(vec.vector(), localIndex, lvWrapper, intersection);
};
cppFunction(lambda);

} else {
DUNE_THROW(Dune::NotImplemented, "fixBoundaryDOFs: A function with this signature is not supported");
}
}

/**
* \brief Register Python bindings for a DirichletValues class.
*
* This function registers Python bindings for a DirichletValues class, allowing it to be used in Python scripts.
* The registered class will have an initializer that takes a `Basis` object. It exposes several member functions to
* Python:
* - `fixBoundaryDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f`.
* - `fixBoundaryDOFsUsingLocalView(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with a
* `LocalView` argument.
* - `fixBoundaryDOFsUsingLocalViewAndIntersection(f)`: Fixes boundary degrees of freedom using a user-defined
* function `f` with `LocalView` and `Intersection` arguments.
* - `fixDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with the boolean vector and
* the basis as arguments.
* - `fixBoundaryDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` than can be called
* with the following arguments:
* - with the boolean vector and the global index.
* - with the boolean vector, the local index and the `LocalView`.
* - with the boolean vector, the local index, the `LocalView` and the `Intersection`.
* - `fixDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with the basis and the boolean
* vector as arguments.
* - `setSingleDOF(i, flag: bool): Fixes or unfixes DOF with index i
* - `isConstrained(i)`: Checks whether index i is constrained
* - `reset()`: Resets the whole container
*
* The following properties can be accessed:
* - `container`: the underlying container of dirichlet flags (as a const reference)
* - `size`: the size of the underlying basis
* - `fixedDOFsize`: the amount of DOFs currently fixed
*
* \tparam DirichletValues The DirichletValues class to be registered.
* \tparam options Variadic template parameters for additional options when defining the Python class.
Expand All @@ -57,60 +156,38 @@ void registerDirichletValues(pybind11::handle scope, pybind11::class_<DirichletV
using Intersection = typename Basis::GridView::Intersection;

pybind11::module scopedf = pybind11::module::import("dune.functions");
typedef Dune::Python::LocalViewWrapper<Basis> LocalViewWrapper;
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
auto lv = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalView",
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()), includes)
.first;
using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;

cls.def(pybind11::init([](const Basis& basis) { return new DirichletValues(basis); }), pybind11::keep_alive<1, 2>());
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalView", Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()),
includes);

// Eigen::Ref needed due to https://pybind11.readthedocs.io/en/stable/advanced/cast/eigen.html#pass-by-reference
cls.def("fixBoundaryDOFs",
[](DirichletValues& self, const std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int)>& f) {
auto lambda = [&](BackendType& vec, const MultiIndex& indexGlobal) {
// we explicitly only allow flat indices
f(vec.vector(), indexGlobal[0]);
};
self.fixBoundaryDOFs(lambda);
});
if (isNew) {
lv.def("bind", &LocalViewWrapper::bind);
lv.def("unbind", &LocalViewWrapper::unbind);
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });

cls.def("fixBoundaryDOFsUsingLocalView",
[](DirichletValues& self,
const std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LocalViewWrapper&)>& f) {
auto lambda = [&](BackendType& vec, int localIndex, LocalView& lv) {
auto lvWrapper = LocalViewWrapper(lv.globalBasis());
// this can be simplified when
// https://gitlab.dune-project.org/staging/dune-functions/-/merge_requests/418 becomes available
pybind11::object obj = pybind11::cast(lv.element());
lvWrapper.bind(obj);
f(vec.vector(), localIndex, lvWrapper);
};
self.fixBoundaryDOFs(lambda);
});
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
}

cls.def(pybind11::init([](const Basis& basis) { return new DirichletValues(basis); }), pybind11::keep_alive<1, 2>());

cls.def(
"fixBoundaryDOFsUsingLocalViewAndIntersection",
[](DirichletValues& self,
const std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LocalViewWrapper&, const Intersection&)>& f) {
auto lambda = [&](BackendType& vec, int localIndex, LocalView& lv, const Intersection& intersection) {
auto lvWrapper = LocalViewWrapper(lv.globalBasis());
// this can be simplified when
// https://gitlab.dune-project.org/staging/dune-functions/-/merge_requests/418 becomes available
pybind11::object obj = pybind11::cast(lv.element());
lvWrapper.bind(obj);
f(vec.vector(), localIndex, lvWrapper, intersection);
};
self.fixBoundaryDOFs(lambda);
});
cls.def_property_readonly("container", &DirichletValues::container);
rath3t marked this conversation as resolved.
Show resolved Hide resolved
cls.def_property_readonly("size", &DirichletValues::size);
cls.def("__len__", &DirichletValues::size);
cls.def_property_readonly("fixedDOFsize", &DirichletValues::fixedDOFsize);
cls.def("isConstrained", [](DirichletValues& self, std::size_t i) -> bool { return self.isConstrained(i); });
cls.def("setSingleDOF", [](DirichletValues& self, std::size_t i, bool flag) { self.setSingleDOF(i, flag); });
henrij22 marked this conversation as resolved.
Show resolved Hide resolved
cls.def("isConstrained", [](DirichletValues& self, MultiIndex i) -> bool { return self.isConstrained(i); });
cls.def("setSingleDOF", [](DirichletValues& self, MultiIndex i, bool flag) { self.setSingleDOF(i, flag); });
cls.def("reset", &DirichletValues::reset);

cls.def("fixDOFs",
[](DirichletValues& self, const std::function<void(const Basis&, Eigen::Ref<Eigen::VectorX<bool>>)>& f) {
auto lambda = [&](const Basis& basis, BackendType& vec) {
// we explicitly only allow flat indices
f(basis, vec.vector());
};
auto lambda = [&](const Basis& basis, BackendType& vec) { f(basis, vec.vector()); };
self.fixDOFs(lambda);
});
}
Expand Down
33 changes: 17 additions & 16 deletions ikarus/python/finiteelements/fe.hh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace Ikarus::Python {
* \throws Dune::NotImplemented If the specified resultType is not supported by the finite element.
*/
template <class FE, class... options>
void registerCalculateAt(pybind11::handle scope, pybind11::class_<FE, options...> cls, auto restultTypesTuple) {
void registerCalculateAt(pybind11::handle scope, pybind11::class_<FE, options...> cls, auto resultTypesTuple) {
using Traits = typename FE::Traits;
using FERequirements = typename FE::Requirement;
cls.def(
Expand All @@ -57,7 +57,7 @@ void registerCalculateAt(pybind11::handle scope, pybind11::class_<FE, options...
std::string resType) {
Eigen::VectorXd result;
bool success = false;
Dune::Hybrid::forEach(restultTypesTuple, [&]<typename RT>(RT i) {
Dune::Hybrid::forEach(resultTypesTuple, [&]<typename RT>(RT i) {
if (resType == toString(i)) {
success = true;
result = self.template calculateAt<RT::template Rebind>(req, local).asVec();
Expand Down Expand Up @@ -113,21 +113,22 @@ void registerFE(pybind11::handle scope, pybind11::class_<FE, options...> cls) {
pybind11::arg("Requirement"), pybind11::arg("MatrixAffordance"), pybind11::arg("elementMatrix").noconvert());

pybind11::module scopedf = pybind11::module::import("dune.functions");
using LocalViewWrapper = Dune::Python::LocalViewWrapper<FlatBasis>;

typedef Dune::Python::LocalViewWrapper<FlatBasis> LocalViewWrapper;
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
auto lv = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalViewWrapper",
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapperWrapper", Dune::MetaType<FlatBasis>()),
includes)
.first;
lv.def("bind", &LocalViewWrapper::bind);
lv.def("unbind", &LocalViewWrapper::unbind);
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });

Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalViewWrapper",
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapperWrapper", Dune::MetaType<FlatBasis>()), includes);

if (isNew) {
lv.def("bind", &LocalViewWrapper::bind);
lv.def("unbind", &LocalViewWrapper::unbind);
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
rath3t marked this conversation as resolved.
Show resolved Hide resolved
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });

Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
}

cls.def(
"localView",
Expand Down
20 changes: 11 additions & 9 deletions ikarus/python/finiteelements/registerferequirements.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <dune/python/common/typeregistry.hh>
#include <dune/python/pybind11/pybind11.h>

#include <ikarus/python/finiteelements/scalarwrapper.hh>
Expand All @@ -22,26 +23,27 @@ void registerFERequirement(pybind11::handle scope, pybind11::class_<FE, options.
"createRequirement", [](pybind11::object /* self */) { return FERequirements(); },
pybind11::return_value_policy::copy);

auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/ferequirements.hh"};
auto [lv, isNotRegistered] = Dune::Python::insertClass<FERequirements>(
auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/ferequirements.hh"};
auto [req, isNew] = Dune::Python::insertClass<FERequirements>(
scope, "FERequirements", Dune::Python::GenerateTypeName(Dune::className<FERequirements>()), includes);
if (isNotRegistered) {
lv.def(pybind11::init());
lv.def(pybind11::init<SolutionVectorType&, ParameterType&>());

lv.def(
if (isNew) {
req.def(pybind11::init());
req.def(pybind11::init<SolutionVectorType&, ParameterType&>());

req.def(
"insertGlobalSolution",
[](FERequirements& self, SolutionVectorType solVec) { self.insertGlobalSolution(solVec); },
"solutionVector"_a.noconvert());
lv.def(
req.def(
"globalSolution", [](FERequirements& self) { return self.globalSolution(); },
pybind11::return_value_policy::reference_internal);
lv.def(
req.def(
"insertParameter",
[](FERequirements& self, ScalarWrapper<double>& parVal) { self.insertParameter(parVal.value()); },
pybind11::keep_alive<1, 2>(), "parameterValue"_a.noconvert());

lv.def("parameter", [](const FERequirements& self) { return self.parameter(); });
req.def("parameter", [](const FERequirements& self) { return self.parameter(); });
}
}
} // namespace Ikarus::Python
11 changes: 11 additions & 0 deletions ikarus/python/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ dune_python_add_test(
python
)

dune_python_add_test(
NAME
pydirichletvalues
SCRIPT
dirichletvaluetest.py
WORKING_DIRECTORY
${CMAKE_CURRENT_SOURCE_DIR}
LABELS
python
)

if(HAVE_DUNE_IGA)
dune_python_add_test(
NAME
Expand Down
Loading
Loading