From afe43aaa758342fa53081431af6753e7fa08f0d5 Mon Sep 17 00:00:00 2001 From: artivis Date: Sun, 21 Jan 2024 19:40:21 +0100 Subject: [PATCH] fix casting float->double closes #282 --- include/manif/impl/cast.h | 35 +++++++++++++++++++++++++ include/manif/impl/lie_group_base.h | 5 +++- include/manif/impl/se2/SE2_base.h | 11 ++++++++ include/manif/impl/se3/SE3_base.h | 16 ++++++++++- include/manif/impl/se_2_3/SE_2_3_base.h | 17 ++++++++++++ include/manif/impl/so2/SO2_base.h | 9 +++++++ include/manif/impl/so3/SO3_base.h | 15 ++++++++++- test/common_tester.h | 25 +++++++++++++++--- test/se2/gtest_se2_map.cpp | 4 +-- test/so3/gtest_so3.cpp | 6 ----- 10 files changed, 129 insertions(+), 14 deletions(-) create mode 100644 include/manif/impl/cast.h diff --git a/include/manif/impl/cast.h b/include/manif/impl/cast.h new file mode 100644 index 00000000..60a17da9 --- /dev/null +++ b/include/manif/impl/cast.h @@ -0,0 +1,35 @@ +#ifndef _MANIF_MANIF_IMPL_CAST_H_ +#define _MANIF_MANIF_IMPL_CAST_H_ + +namespace manif { +namespace internal { + +template +struct CastEvaluatorImpl { + template + static auto run(const T& o) -> typename T::template LieGroupTemplate { + return typename T::template LieGroupTemplate( + o.coeffs().template cast() + ); + } +}; + +template +struct CastEvaluator : CastEvaluatorImpl { + using Base = CastEvaluatorImpl; + + CastEvaluator(const Derived& xptr) : xptr_(xptr) {} + + auto run() const -> typename Derived::template LieGroupTemplate { + return Base::run(xptr_); + } + +protected: + + const Derived& xptr_; +}; + +} // namespace internal +} // namespace manif + +#endif // _MANIF_MANIF_IMPL_CAST_H_ diff --git a/include/manif/impl/lie_group_base.h b/include/manif/impl/lie_group_base.h index 23ec128c..945669a2 100644 --- a/include/manif/impl/lie_group_base.h +++ b/include/manif/impl/lie_group_base.h @@ -6,6 +6,7 @@ #include "manif/impl/eigen.h" #include "manif/impl/tangent_base.h" #include "manif/impl/assignment_assert.h" +#include "manif/impl/cast.h" #include "manif/constants.h" @@ -415,7 +416,9 @@ template typename LieGroupBase<_Derived>::template LieGroupTemplate<_NewScalar> LieGroupBase<_Derived>::cast() const { - return LieGroupTemplate<_NewScalar>(coeffs().template cast<_NewScalar>()); + return internal::CastEvaluator< + typename internal::traits<_Derived>::Base, _NewScalar + >(derived()).run(); } template diff --git a/include/manif/impl/se2/SE2_base.h b/include/manif/impl/se2/SE2_base.h index fbd56e58..b4fbd1ed 100644 --- a/include/manif/impl/se2/SE2_base.h +++ b/include/manif/impl/se2/SE2_base.h @@ -424,6 +424,17 @@ struct AssignmentEvaluatorImpl> } }; +//! @brief Cast specialization for SE2Base objects. +template +struct CastEvaluatorImpl, NewScalar> { + template + static auto run(const T& o) -> typename Derived::template LieGroupTemplate { + return typename Derived::template LieGroupTemplate( + NewScalar(o.x()), NewScalar(o.y()), NewScalar(o.angle()) + ); + } +}; + } /* namespace internal */ } /* namespace manif */ diff --git a/include/manif/impl/se3/SE3_base.h b/include/manif/impl/se3/SE3_base.h index cc84a652..f7a69c30 100644 --- a/include/manif/impl/se3/SE3_base.h +++ b/include/manif/impl/se3/SE3_base.h @@ -443,7 +443,7 @@ struct RandomEvaluatorImpl> } }; -//! @brief Assignment assert specialization for SE2Base objects +//! @brief Assignment assert specialization for SE3Base objects template struct AssignmentEvaluatorImpl> { @@ -461,6 +461,20 @@ struct AssignmentEvaluatorImpl> } }; +//! @brief Cast specialization for SE3Base objects. +template +struct CastEvaluatorImpl, NewScalar> { + template + static auto run(const T& o) -> typename Derived::template LieGroupTemplate { + const typename SE3Base::QuaternionDataType q = o.quat(); + const typename SE3Base::Translation t = o.translation(); + + return typename Derived::template LieGroupTemplate( + t.template cast(), q.template cast().normalized() + ); + } +}; + } /* namespace internal */ } /* namespace manif */ diff --git a/include/manif/impl/se_2_3/SE_2_3_base.h b/include/manif/impl/se_2_3/SE_2_3_base.h index 01920c50..329155ce 100644 --- a/include/manif/impl/se_2_3/SE_2_3_base.h +++ b/include/manif/impl/se_2_3/SE_2_3_base.h @@ -468,6 +468,23 @@ struct AssignmentEvaluatorImpl> } }; +//! @brief Cast specialization for SE_2_3Base objects. +template +struct CastEvaluatorImpl, NewScalar> { + template + static auto run(const T& o) -> typename Derived::template LieGroupTemplate { + const typename SE_2_3Base::QuaternionDataType q = o.quat(); + const typename SE_2_3Base::Translation t = o.translation(); + const typename SE_2_3Base::LinearVelocity v = o.linearVelocity(); + + return typename Derived::template LieGroupTemplate( + t.template cast(), + q.template cast().normalized(), + v.template cast() + ); + } +}; + } /* namespace internal */ } /* namespace manif */ diff --git a/include/manif/impl/so2/SO2_base.h b/include/manif/impl/so2/SO2_base.h index 82919522..ea1506ab 100644 --- a/include/manif/impl/so2/SO2_base.h +++ b/include/manif/impl/so2/SO2_base.h @@ -340,6 +340,15 @@ struct AssignmentEvaluatorImpl> } }; +//! @brief Cast specialization for SO2Base objects. +template +struct CastEvaluatorImpl, NewScalar> { + template + static auto run(const T& o) -> typename Derived::template LieGroupTemplate { + return typename Derived::template LieGroupTemplate(NewScalar(o.angle())); + } +}; + } /* namespace internal */ } /* namespace manif */ diff --git a/include/manif/impl/so3/SO3_base.h b/include/manif/impl/so3/SO3_base.h index 303c22c9..308649b5 100644 --- a/include/manif/impl/so3/SO3_base.h +++ b/include/manif/impl/so3/SO3_base.h @@ -403,7 +403,7 @@ struct RandomEvaluatorImpl> } }; -//! @brief Assignment assert specialization for SE2Base objects +//! @brief Assignment assert specialization for SO3Base objects template struct AssignmentEvaluatorImpl> { @@ -421,6 +421,19 @@ struct AssignmentEvaluatorImpl> } }; +//! @brief Cast specialization for SO3Base objects. +template +struct CastEvaluatorImpl, NewScalar> { + template + static auto run(const T& o) -> typename Derived::template LieGroupTemplate { + const typename SO3Base::QuaternionDataType q = o.quat(); + + return typename Derived::template LieGroupTemplate( + q.template cast().normalized() + ); + } +}; + } /* namespace internal */ } /* namespace manif */ diff --git a/test/common_tester.h b/test/common_tester.h index e67a5205..d47cd9c5 100644 --- a/test/common_tester.h +++ b/test/common_tester.h @@ -104,7 +104,9 @@ TEST_P(TEST_##manifold##_TESTER, TEST_##manifold##_SMALL_ADJ) \ { evalSmallAdj(); } \ TEST_F(TEST_##manifold##_TESTER, TEST_##manifold##_IDENTITY_ACT_POINT) \ - { evalIdentityActPoint(); } + { evalIdentityActPoint(); } \ + TEST_P(TEST_##manifold##_TESTER, TEST_##manifold##_CAST) \ + { evalCast(); } #define MANIF_TEST_JACOBIANS(manifold) \ using manifold##JacobiansTester = JacobianTester; \ @@ -703,6 +705,23 @@ class CommonTester EXPECT_EIGEN_NEAR(pin, pout); } + void evalCast() { + using NewScalar = typename std::conditional< + std::is_same::value, double, float + >::type; + + EXPECT_NO_THROW( + auto state = getState().template cast(); + ); + + int i=0; + EXPECT_NO_THROW( + for (; i < 10000; ++i) { + auto state = LieGroup::Random().template cast(); + } + ) << "+= failed at iteration " << i; + } + protected: // relax eps for float type @@ -1044,8 +1063,8 @@ class JacobianTester Jrinv = tan.rjacinv(); Jlinv = tan.ljacinv(); - EXPECT_EIGEN_NEAR(Jacobian::Identity(), Jr*Jrinv); - EXPECT_EIGEN_NEAR(Jacobian::Identity(), Jl*Jlinv); + EXPECT_EIGEN_NEAR(Jacobian::Identity(), Jr*Jrinv, tol_); + EXPECT_EIGEN_NEAR(Jacobian::Identity(), Jl*Jlinv, tol_); } void evalActJac() diff --git a/test/se2/gtest_se2_map.cpp b/test/se2/gtest_se2_map.cpp index 4011336d..76f52d9c 100644 --- a/test/se2/gtest_se2_map.cpp +++ b/test/se2/gtest_se2_map.cpp @@ -53,13 +53,13 @@ TEST(TEST_SE2, TEST_SE2_MAP_CAST) EXPECT_DOUBLE_EQ(4, se2d.x()); EXPECT_DOUBLE_EQ(2, se2d.y()); - EXPECT_DOUBLE_EQ(MANIF_PI, se2d.angle()); + EXPECT_DOUBLE_EQ(MANIF_PI, std::abs(se2d.angle())); SE2f se2f = se2d.cast(); EXPECT_FLOAT_EQ(4, se2f.x()); EXPECT_FLOAT_EQ(2, se2f.y()); - EXPECT_FLOAT_EQ(MANIF_PI, se2f.angle()); + EXPECT_FLOAT_EQ(MANIF_PI, std::abs(se2f.angle())); } TEST(TEST_SE2, TEST_SE2_MAP_IDENTITY) diff --git a/test/so3/gtest_so3.cpp b/test/so3/gtest_so3.cpp index 9471f6e8..09606f39 100644 --- a/test/so3/gtest_so3.cpp +++ b/test/so3/gtest_so3.cpp @@ -619,12 +619,6 @@ TEST(TEST_SO3, TEST_SO3_NORMALIZE) #endif -MANIF_TEST(SO3f); - -MANIF_TEST_MAP(SO3f); - -MANIF_TEST_JACOBIANS(SO3f); - MANIF_TEST(SO3d); MANIF_TEST_MAP(SO3d);