average.hpp 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. /// @file
  2. /// Calculation of biinvariant means.
  3. #ifndef SOPHUS_AVERAGE_HPP
  4. #define SOPHUS_AVERAGE_HPP
  5. #include "common.hpp"
  6. #include "rxso2.hpp"
  7. #include "rxso3.hpp"
  8. #include "se2.hpp"
  9. #include "se3.hpp"
  10. #include "sim2.hpp"
  11. #include "sim3.hpp"
  12. #include "so2.hpp"
  13. #include "so3.hpp"
  14. namespace Sophus {
  15. /// Calculates mean iteratively.
  16. ///
  17. /// Returns ``nullopt`` if it does not converge.
  18. ///
  19. template <class SequenceContainer>
  20. optional<typename SequenceContainer::value_type> iterativeMean(
  21. SequenceContainer const& foo_Ts_bar, int max_num_iterations) {
  22. size_t N = foo_Ts_bar.size();
  23. SOPHUS_ENSURE(N >= 1, "N must be >= 1.");
  24. using Group = typename SequenceContainer::value_type;
  25. using Scalar = typename Group::Scalar;
  26. using Tangent = typename Group::Tangent;
  27. // This implements the algorithm in the beginning of Sec. 4.2 in
  28. // ftp://ftp-sop.inria.fr/epidaure/Publications/Arsigny/arsigny_rr_biinvariant_average.pdf.
  29. Group foo_T_average = foo_Ts_bar.front();
  30. Scalar w = Scalar(1. / N);
  31. for (int i = 0; i < max_num_iterations; ++i) {
  32. Tangent average;
  33. setToZero<Tangent>(average);
  34. for (Group const& foo_T_bar : foo_Ts_bar) {
  35. average += w * (foo_T_average.inverse() * foo_T_bar).log();
  36. }
  37. Group foo_T_newaverage = foo_T_average * Group::exp(average);
  38. if (squaredNorm<Tangent>(
  39. (foo_T_newaverage.inverse() * foo_T_average).log()) <
  40. Constants<Scalar>::epsilon()) {
  41. return foo_T_newaverage;
  42. }
  43. foo_T_average = foo_T_newaverage;
  44. }
  45. // LCOV_EXCL_START
  46. return nullopt;
  47. // LCOV_EXCL_STOP
  48. }
  49. #ifdef DOXYGEN_SHOULD_SKIP_THIS
  50. /// Mean implementation for any Lie group.
  51. template <class SequenceContainer, class Scalar>
  52. optional<typename SequenceContainer::value_type> average(
  53. SequenceContainer const& foo_Ts_bar);
  54. #else
  55. // Mean implementation for SO(2).
  56. template <class SequenceContainer,
  57. class Scalar = typename SequenceContainer::value_type::Scalar>
  58. enable_if_t<
  59. std::is_same<typename SequenceContainer::value_type, SO2<Scalar> >::value,
  60. optional<typename SequenceContainer::value_type> >
  61. average(SequenceContainer const& foo_Ts_bar) {
  62. // This implements rotational part of Proposition 12 from Sec. 6.2 of
  63. // ftp://ftp-sop.inria.fr/epidaure/Publications/Arsigny/arsigny_rr_biinvariant_average.pdf.
  64. size_t N = std::distance(std::begin(foo_Ts_bar), std::end(foo_Ts_bar));
  65. SOPHUS_ENSURE(N >= 1, "N must be >= 1.");
  66. SO2<Scalar> foo_T_average = foo_Ts_bar.front();
  67. Scalar w = Scalar(1. / N);
  68. Scalar average(0);
  69. for (SO2<Scalar> const& foo_T_bar : foo_Ts_bar) {
  70. average += w * (foo_T_average.inverse() * foo_T_bar).log();
  71. }
  72. return foo_T_average * SO2<Scalar>::exp(average);
  73. }
  74. // Mean implementation for RxSO(2).
  75. template <class SequenceContainer,
  76. class Scalar = typename SequenceContainer::value_type::Scalar>
  77. enable_if_t<
  78. std::is_same<typename SequenceContainer::value_type, RxSO2<Scalar> >::value,
  79. optional<typename SequenceContainer::value_type> >
  80. average(SequenceContainer const& foo_Ts_bar) {
  81. size_t N = std::distance(std::begin(foo_Ts_bar), std::end(foo_Ts_bar));
  82. SOPHUS_ENSURE(N >= 1, "N must be >= 1.");
  83. RxSO2<Scalar> foo_T_average = foo_Ts_bar.front();
  84. Scalar w = Scalar(1. / N);
  85. Vector2<Scalar> average(Scalar(0), Scalar(0));
  86. for (RxSO2<Scalar> const& foo_T_bar : foo_Ts_bar) {
  87. average += w * (foo_T_average.inverse() * foo_T_bar).log();
  88. }
  89. return foo_T_average * RxSO2<Scalar>::exp(average);
  90. }
  91. namespace details {
  92. template <class T>
  93. void getQuaternion(T const&);
  94. template <class Scalar>
  95. Eigen::Quaternion<Scalar> getUnitQuaternion(SO3<Scalar> const& R) {
  96. return R.unit_quaternion();
  97. }
  98. template <class Scalar>
  99. Eigen::Quaternion<Scalar> getUnitQuaternion(RxSO3<Scalar> const& sR) {
  100. return sR.so3().unit_quaternion();
  101. }
  102. template <class SequenceContainer,
  103. class Scalar = typename SequenceContainer::value_type::Scalar>
  104. Eigen::Quaternion<Scalar> averageUnitQuaternion(
  105. SequenceContainer const& foo_Ts_bar) {
  106. // This: http://stackoverflow.com/a/27410865/1221742
  107. size_t N = std::distance(std::begin(foo_Ts_bar), std::end(foo_Ts_bar));
  108. SOPHUS_ENSURE(N >= 1, "N must be >= 1.");
  109. Eigen::Matrix<Scalar, 4, Eigen::Dynamic> Q(4, N);
  110. int i = 0;
  111. Scalar w = Scalar(1. / N);
  112. for (auto const& foo_T_bar : foo_Ts_bar) {
  113. Q.col(i) = w * details::getUnitQuaternion(foo_T_bar).coeffs();
  114. ++i;
  115. }
  116. Eigen::Matrix<Scalar, 4, 4> QQt = Q * Q.transpose();
  117. // TODO: Figure out why we can't use SelfAdjointEigenSolver here.
  118. Eigen::EigenSolver<Eigen::Matrix<Scalar, 4, 4> > es(QQt);
  119. std::complex<Scalar> max_eigenvalue = es.eigenvalues()[0];
  120. Eigen::Matrix<std::complex<Scalar>, 4, 1> max_eigenvector =
  121. es.eigenvectors().col(0);
  122. for (int i = 1; i < 4; i++) {
  123. if (std::norm(es.eigenvalues()[i]) > std::norm(max_eigenvalue)) {
  124. max_eigenvalue = es.eigenvalues()[i];
  125. max_eigenvector = es.eigenvectors().col(i);
  126. }
  127. }
  128. Eigen::Quaternion<Scalar> quat;
  129. quat.coeffs() << //
  130. max_eigenvector[0].real(), //
  131. max_eigenvector[1].real(), //
  132. max_eigenvector[2].real(), //
  133. max_eigenvector[3].real();
  134. return quat;
  135. }
  136. } // namespace details
  137. // Mean implementation for SO(3).
  138. //
  139. // TODO: Detect degenerated cases and return nullopt.
  140. template <class SequenceContainer,
  141. class Scalar = typename SequenceContainer::value_type::Scalar>
  142. enable_if_t<
  143. std::is_same<typename SequenceContainer::value_type, SO3<Scalar> >::value,
  144. optional<typename SequenceContainer::value_type> >
  145. average(SequenceContainer const& foo_Ts_bar) {
  146. return SO3<Scalar>(details::averageUnitQuaternion(foo_Ts_bar));
  147. }
  148. // Mean implementation for R x SO(3).
  149. template <class SequenceContainer,
  150. class Scalar = typename SequenceContainer::value_type::Scalar>
  151. enable_if_t<
  152. std::is_same<typename SequenceContainer::value_type, RxSO3<Scalar> >::value,
  153. optional<typename SequenceContainer::value_type> >
  154. average(SequenceContainer const& foo_Ts_bar) {
  155. size_t N = std::distance(std::begin(foo_Ts_bar), std::end(foo_Ts_bar));
  156. SOPHUS_ENSURE(N >= 1, "N must be >= 1.");
  157. Scalar scale_sum = Scalar(0);
  158. using std::exp;
  159. using std::log;
  160. for (RxSO3<Scalar> const& foo_T_bar : foo_Ts_bar) {
  161. scale_sum += log(foo_T_bar.scale());
  162. }
  163. return RxSO3<Scalar>(exp(scale_sum / Scalar(N)),
  164. SO3<Scalar>(details::averageUnitQuaternion(foo_Ts_bar)));
  165. }
  166. template <class SequenceContainer,
  167. class Scalar = typename SequenceContainer::value_type::Scalar>
  168. enable_if_t<
  169. std::is_same<typename SequenceContainer::value_type, SE2<Scalar> >::value,
  170. optional<typename SequenceContainer::value_type> >
  171. average(SequenceContainer const& foo_Ts_bar, int max_num_iterations = 20) {
  172. // TODO: Implement Proposition 12 from Sec. 6.2 of
  173. // ftp://ftp-sop.inria.fr/epidaure/Publications/Arsigny/arsigny_rr_biinvariant_average.pdf.
  174. return iterativeMean(foo_Ts_bar, max_num_iterations);
  175. }
  176. template <class SequenceContainer,
  177. class Scalar = typename SequenceContainer::value_type::Scalar>
  178. enable_if_t<
  179. std::is_same<typename SequenceContainer::value_type, Sim2<Scalar> >::value,
  180. optional<typename SequenceContainer::value_type> >
  181. average(SequenceContainer const& foo_Ts_bar, int max_num_iterations = 20) {
  182. return iterativeMean(foo_Ts_bar, max_num_iterations);
  183. }
  184. template <class SequenceContainer,
  185. class Scalar = typename SequenceContainer::value_type::Scalar>
  186. enable_if_t<
  187. std::is_same<typename SequenceContainer::value_type, SE3<Scalar> >::value,
  188. optional<typename SequenceContainer::value_type> >
  189. average(SequenceContainer const& foo_Ts_bar, int max_num_iterations = 20) {
  190. return iterativeMean(foo_Ts_bar, max_num_iterations);
  191. }
  192. template <class SequenceContainer,
  193. class Scalar = typename SequenceContainer::value_type::Scalar>
  194. enable_if_t<
  195. std::is_same<typename SequenceContainer::value_type, Sim3<Scalar> >::value,
  196. optional<typename SequenceContainer::value_type> >
  197. average(SequenceContainer const& foo_Ts_bar, int max_num_iterations = 20) {
  198. return iterativeMean(foo_Ts_bar, max_num_iterations);
  199. }
  200. #endif // DOXYGEN_SHOULD_SKIP_THIS
  201. } // namespace Sophus
  202. #endif // SOPHUS_AVERAGE_HPP