Skip to content

Commit

Permalink
Fix subspacelocalview
Browse files Browse the repository at this point in the history
  • Loading branch information
henrij22 committed Jul 17, 2024
1 parent b57a5d8 commit a814144
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 55 deletions.
68 changes: 36 additions & 32 deletions ikarus/python/dirichletvalues/dirichletvalues.hh
Original file line number Diff line number Diff line change
Expand Up @@ -39,39 +39,27 @@ namespace Impl {
using FixBoundaryDOFsWithIntersectionFunction =
std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LV&, const IS&)>;

template <typename Basis, bool registerBasis = false>
auto registerLocalView() {
template <typename Basis>
auto registerSubSpaceLocalView() {
pybind11::module scopedf = pybind11::module::import("dune.functions");
using LocalView = Dune::Python::LocalViewWrapper<Basis>;
using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;

auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};

// also register subspace basis
if constexpr (registerBasis) {
auto construct = [](const Basis& basis) { return new Basis(basis); };

// This if statement does absolutly nothing
if (Dune::Python::findInTypeRegistry<Basis>().second) {
auto [basisCls, isNotRegistered] = Dune::Python::insertClass<Basis>(
scopedf, "SubspaceBasis", Dune::Python::GenerateTypeName(Dune::className<Basis>()), includes);
if (isNotRegistered)
Dune::Python::registerSubspaceBasis(scopedf, basisCls);
}
// Dune::Python::registerBasisType(scopedf, basisCls, construct, std::false_type{});
} else {
// auto [lv, isNotRegistered] = Dune::Python::insertClass<LocalView>(
// scopedf, "LocalView",
// Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()), includes);

// if (isNotRegistered) {
// lv.def("bind", &LocalView::bind);
// lv.def("unbind", &LocalView::unbind);
// lv.def("index", [](const LocalView& localView, int index) { return localView.index(index); });
// lv.def("__len__", [](LocalView& self) -> int { return self.size(); });

// Dune::Python::Functions::registerTree<typename LocalView::Tree>(lv);
// lv.def("tree", [](const LocalView& view) { return view.tree(); });
// }
Dune::Python::insertClass<Basis>(scopedf, "SubspaceBasis_" + Dune::className<typename Basis::PrefixPath>(),
Dune::Python::GenerateTypeName(Dune::className<Basis>()), includes);

auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalView_" + Dune::className<typename Basis::PrefixPath>(),
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()), includes);
if (isNew) {
lv.def("bind", &LocalViewWrapper::bind);
lv.def("unbind", &LocalViewWrapper::unbind);
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });

Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
}
}
} // namespace Impl
Expand All @@ -98,7 +86,7 @@ void forwardCorrectFunction(DirichletValues& dirichletValues, const pybind11::fu
} else if (numParams == 3) {
auto lambda = [&](BackendType& vec, int localIndex, auto&& lv) {
using SubSpaceBasis = typename std::remove_cvref_t<decltype(lv)>::GlobalBasis;
Impl::registerLocalView<SubSpaceBasis, true>();
Impl::registerSubSpaceLocalView<SubSpaceBasis>();

using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
auto lvWrapper = SubSpaceLocalViewWrapper(lv);
Expand All @@ -112,7 +100,7 @@ void forwardCorrectFunction(DirichletValues& dirichletValues, const pybind11::fu
} else if (numParams == 4) {
auto lambda = [&](BackendType& vec, int localIndex, auto&& lv, const Intersection& intersection) {
using SubSpaceBasis = typename std::remove_cvref_t<decltype(lv)>::GlobalBasis;
Impl::registerLocalView<SubSpaceBasis, true>();
Impl::registerSubSpaceLocalView<SubSpaceBasis>();

using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
auto lvWrapper = SubSpaceLocalViewWrapper(lv);
Expand Down Expand Up @@ -169,7 +157,23 @@ void registerDirichletValues(pybind11::handle scope, pybind11::class_<DirichletV
using LocalView = typename Basis::LocalView;
using Intersection = typename Basis::GridView::Intersection;

Impl::registerLocalView<Basis>();
pybind11::module scopedf = pybind11::module::import("dune.functions");
using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;

auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalView", Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()),
includes);

if (isNew) {
lv.def("bind", &LocalViewWrapper::bind);
lv.def("unbind", &LocalViewWrapper::unbind);
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });

Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
}

cls.def(pybind11::init([](const Basis& basis) { return new DirichletValues(basis); }), pybind11::keep_alive<1, 2>());

Expand Down
29 changes: 15 additions & 14 deletions ikarus/python/finiteelements/fe.hh
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,22 @@ void registerFE(pybind11::handle scope, pybind11::class_<FE, options...> cls) {
pybind11::arg("Requirement"), pybind11::arg("MatrixAffordance"), pybind11::arg("elementMatrix").noconvert());

pybind11::module scopedf = pybind11::module::import("dune.functions");
using LocalViewWrapper = Dune::Python::LocalViewWrapper<FlatBasis>;

typedef Dune::Python::LocalViewWrapper<FlatBasis> LocalViewWrapper;
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
auto lv = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalViewWrapper",
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapperWrapper", Dune::MetaType<FlatBasis>()),
includes)
.first;
lv.def("bind", &LocalViewWrapper::bind);
lv.def("unbind", &LocalViewWrapper::unbind);
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });

Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
scopedf, "LocalViewWrapper",
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapperWrapper", Dune::MetaType<FlatBasis>()), includes);

if (isNew) {
lv.def("bind", &LocalViewWrapper::bind);
lv.def("unbind", &LocalViewWrapper::unbind);
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });

Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
}

cls.def(
"localView",
Expand Down
20 changes: 11 additions & 9 deletions ikarus/python/finiteelements/registerferequirements.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <dune/python/common/typeregistry.hh>
#include <dune/python/pybind11/pybind11.h>

#include <ikarus/python/finiteelements/valuewrapper.hh>
Expand All @@ -22,25 +23,26 @@ void registerFERequirement(pybind11::handle scope, pybind11::class_<FE, options.
"createRequirement", [](pybind11::object /* self */) { return FERequirements(); },
pybind11::return_value_policy::copy);

auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/ferequirements.hh"};
auto [lv, isNotRegistered] = Dune::Python::insertClass<FERequirements>(
auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/ferequirements.hh"};
auto [req, isNew] = Dune::Python::insertClass<FERequirements>(
scope, "FERequirements", Dune::Python::GenerateTypeName(Dune::className<FERequirements>()), includes);
if (isNotRegistered) {
lv.def(pybind11::init());
lv.def(pybind11::init<SolutionVectorType&, ParameterType&>());

lv.def(
if (isNew) {
req.def(pybind11::init());
req.def(pybind11::init<SolutionVectorType&, ParameterType&>());

req.def(
"insertGlobalSolution",
[](FERequirements& self, SolutionVectorType solVec) { self.insertGlobalSolution(solVec); },
"solutionVector"_a.noconvert());
lv.def(
req.def(
"globalSolution", [](FERequirements& self) { return self.globalSolution(); },
pybind11::return_value_policy::reference_internal);
lv.def(
req.def(
"insertParameter", [](FERequirements& self, ValueWrapper<double>& parVal) { self.insertParameter(parVal.val); },
pybind11::keep_alive<1, 2>(), "parameterValue"_a.noconvert());

lv.def("parameter", [](const FERequirements& self) { return self.parameter(); });
req.def("parameter", [](const FERequirements& self) { return self.parameter(); });
}
}
} // namespace Ikarus::Python
1 change: 1 addition & 0 deletions ikarus/python/test/testdirichletvalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,6 @@ def fixTopSide(vec, localIndex, localView, intersection):
assert dirichletValues2.fixedDOFsize == 0
assert sum(dirichletValues2.container) == 0


if __name__ == "__main__":
testDirichletValues()

0 comments on commit a814144

Please sign in to comment.