tiny_solver_autodiff_function.h 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2019 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Author: mierle@gmail.com (Keir Mierle)
  30. //
  31. // WARNING WARNING WARNING
  32. // WARNING WARNING WARNING Tiny solver is experimental and will change.
  33. // WARNING WARNING WARNING
  34. #ifndef CERES_PUBLIC_TINY_SOLVER_AUTODIFF_FUNCTION_H_
  35. #define CERES_PUBLIC_TINY_SOLVER_AUTODIFF_FUNCTION_H_
  36. #include <memory>
  37. #include <type_traits>
  38. #include "Eigen/Core"
  39. #include "ceres/jet.h"
  40. #include "ceres/types.h" // For kImpossibleValue.
  41. namespace ceres {
  // An adapter around autodiff-style CostFunctors to enable easier use of
  // TinySolver. See the examples below showing how to use it:
  //
  //   // Example for a cost functor with a static residual size.
  //   // Same as an autodiff cost functor, but taking only a single parameter
  //   // block.
  //   struct MyFunctor {
  //     template <typename T>
  //     bool operator()(const T* const parameters, T* residuals) const {
  //       const T& x = parameters[0];
  //       const T& y = parameters[1];
  //       const T& z = parameters[2];
  //       residuals[0] = x + 2. * y + 4. * z;
  //       residuals[1] = y * z;
  //       return true;
  //     }
  //   };
  //
  //   typedef TinySolverAutoDiffFunction<MyFunctor, 2, 3>
  //       AutoDiffFunction;
  //
  //   MyFunctor my_functor;
  //   AutoDiffFunction f(my_functor);
  //
  //   Eigen::Vector3d x = ...;
  //   TinySolver<AutoDiffFunction> solver;
  //   solver.Solve(f, &x);
  //
  //   // Example for a cost functor with a dynamic residual size.
  //   // NumResiduals() supplies the number of residuals at runtime; it is
  //   // queried only when the adapter is constructed.
  //   // Same functionality as in tiny_solver.h but with autodiff.
  //   struct MyFunctorWithDynamicResiduals {
  //     int NumResiduals() const {
  //       return 2;
  //     }
  //
  //     template <typename T>
  //     bool operator()(const T* const parameters, T* residuals) const {
  //       const T& x = parameters[0];
  //       const T& y = parameters[1];
  //       const T& z = parameters[2];
  //       residuals[0] = x + static_cast<T>(2.) * y + static_cast<T>(4.) * z;
  //       residuals[1] = y * z;
  //       return true;
  //     }
  //   };
  //
  //   typedef TinySolverAutoDiffFunction<MyFunctorWithDynamicResiduals,
  //                                      Eigen::Dynamic,
  //                                      3>
  //       AutoDiffFunctionWithDynamicResiduals;
  //
  //   MyFunctorWithDynamicResiduals my_functor_dyn;
  //   AutoDiffFunctionWithDynamicResiduals f(my_functor_dyn);
  //
  //   Eigen::Vector3d x = ...;
  //   TinySolver<AutoDiffFunctionWithDynamicResiduals> solver;
  //   solver.Solve(f, &x);
  //
  // The adapter stores a reference to the cost functor, so the functor must
  // outlive the adapter (do not construct it from a temporary).
  //
  // WARNING: The cost function adapter is not thread safe.
  101. template <typename CostFunctor,
  102. int kNumResiduals,
  103. int kNumParameters,
  104. typename T = double>
  105. class TinySolverAutoDiffFunction {
  106. public:
  107. // This class needs to have an Eigen aligned operator new as it contains
  108. // as a member a Jet type, which itself has a fixed-size Eigen type as member.
  109. EIGEN_MAKE_ALIGNED_OPERATOR_NEW
  110. explicit TinySolverAutoDiffFunction(const CostFunctor& cost_functor)
  111. : cost_functor_(cost_functor) {
  112. Initialize<kNumResiduals>(cost_functor);
  113. }
  114. using Scalar = T;
  115. enum {
  116. NUM_PARAMETERS = kNumParameters,
  117. NUM_RESIDUALS = kNumResiduals,
  118. };
  119. // This is similar to AutoDifferentiate(), but since there is only one
  120. // parameter block it is easier to inline to avoid overhead.
  121. bool operator()(const T* parameters, T* residuals, T* jacobian) const {
  122. if (jacobian == nullptr) {
  123. // No jacobian requested, so just directly call the cost function with
  124. // doubles, skipping jets and derivatives.
  125. return cost_functor_(parameters, residuals);
  126. }
  127. // Initialize the input jets with passed parameters.
  128. for (int i = 0; i < kNumParameters; ++i) {
  129. jet_parameters_[i].a = parameters[i]; // Scalar part.
  130. jet_parameters_[i].v.setZero(); // Derivative part.
  131. jet_parameters_[i].v[i] = T(1.0);
  132. }
  133. // Initialize the output jets such that we can detect user errors.
  134. for (int i = 0; i < num_residuals_; ++i) {
  135. jet_residuals_[i].a = kImpossibleValue;
  136. jet_residuals_[i].v.setConstant(kImpossibleValue);
  137. }
  138. // Execute the cost function, but with jets to find the derivative.
  139. if (!cost_functor_(jet_parameters_, jet_residuals_.data())) {
  140. return false;
  141. }
  142. // Copy the jacobian out of the derivative part of the residual jets.
  143. Eigen::Map<Eigen::Matrix<T, kNumResiduals, kNumParameters>> jacobian_matrix(
  144. jacobian, num_residuals_, kNumParameters);
  145. for (int r = 0; r < num_residuals_; ++r) {
  146. residuals[r] = jet_residuals_[r].a;
  147. // Note that while this looks like a fast vectorized write, in practice it
  148. // unfortunately thrashes the cache since the writes to the column-major
  149. // jacobian are strided (e.g. rows are non-contiguous).
  150. jacobian_matrix.row(r) = jet_residuals_[r].v;
  151. }
  152. return true;
  153. }
  154. int NumResiduals() const {
  155. return num_residuals_; // Set by Initialize.
  156. }
  157. private:
  158. const CostFunctor& cost_functor_;
  159. // The number of residuals at runtime.
  160. // This will be overridden if NUM_RESIDUALS == Eigen::Dynamic.
  161. int num_residuals_ = kNumResiduals;
  162. // To evaluate the cost function with jets, temporary storage is needed. These
  163. // are the buffers that are used during evaluation; parameters for the input,
  164. // and jet_residuals_ are where the final cost and derivatives end up.
  165. //
  166. // Since this buffer is used for evaluation, the adapter is not thread safe.
  167. using JetType = Jet<T, kNumParameters>;
  168. mutable JetType jet_parameters_[kNumParameters];
  169. // Eigen::Matrix serves as static or dynamic container.
  170. mutable Eigen::Matrix<JetType, kNumResiduals, 1> jet_residuals_;
  171. // The number of residuals is dynamically sized and the number of
  172. // parameters is statically sized.
  173. template <int R>
  174. typename std::enable_if<(R == Eigen::Dynamic), void>::type Initialize(
  175. const CostFunctor& function) {
  176. jet_residuals_.resize(function.NumResiduals());
  177. num_residuals_ = function.NumResiduals();
  178. }
  179. // The number of parameters and residuals are statically sized.
  180. template <int R>
  181. typename std::enable_if<(R != Eigen::Dynamic), void>::type Initialize(
  182. const CostFunctor& /* function */) {
  183. num_residuals_ = kNumResiduals;
  184. }
  185. };
  186. } // namespace ceres
  187. #endif // CERES_PUBLIC_TINY_SOLVER_AUTODIFF_FUNCTION_H_