dense_qr.cc 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2022 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Author: sameeragarwal@google.com (Sameer Agarwal)
  30. #include "ceres/dense_qr.h"
  31. #include <algorithm>
  32. #include <memory>
  33. #include <string>
  34. #ifndef CERES_NO_CUDA
  35. #include "ceres/context_impl.h"
  36. #include "cublas_v2.h"
  37. #include "cusolverDn.h"
  38. #endif // CERES_NO_CUDA
  #ifndef CERES_NO_LAPACK
  // LAPACK routines used to solve a linear least squares problem with a QR
  // factorization, in three stages:
  //
  //   A * x = b
  //   Q * R * x = b          (dgeqrf: factor A = Q * R)
  //   R * x = Q' * b         (dormqr: apply Q' to b)
  //   x = R^{-1} * Q' * b    (dtrtrs: triangular solve with R)
  //
  // All matrices are column major and, following the Fortran calling
  // convention, every argument is passed by pointer.
  // clang-format off

  // dgeqrf: QR factorization of the m x n matrix a.
  //
  //   m, n   Number of rows and columns of a.
  //   a      On entry, the matrix to factor. On exit, R on and above the
  //          diagonal and the elementary reflectors defining Q below it.
  //   lda    Leading dimension of a; lda >= max(1, m).
  //   tau    Array of size min(m, n) receiving the scalar factors of the
  //          elementary reflectors.
  //   work   Workspace of size max(1, lwork). If info = 0 on exit, work[0]
  //          holds the optimal workspace size.
  //   lwork  Size of work when >= 1. lwork = -1 requests a workspace query:
  //          only the optimal size is computed and returned in work[0].
  //   info   0 on success; -i if the i-th argument had an illegal value.
  extern "C" void dgeqrf_(const int* m,
                          const int* n,
                          double* a,
                          const int* lda,
                          double* tau,
                          double* work,
                          const int* lwork,
                          int* info);

  // dormqr: multiplies the m x n matrix c by Q or Q', where Q is the product
  // of the k elementary reflectors returned by dgeqrf.
  //
  //   side   'L' applies Q or Q' from the left, 'R' from the right.
  //   trans  'N' applies Q, 'T' applies Q'.
  //   m, n   Number of rows and columns of c.
  //   k      Number of elementary reflectors defining Q; m >= k >= 0 when
  //          side = 'L', n >= k >= 0 when side = 'R'.
  //   a      lda x k matrix holding the reflectors as returned by dgeqrf.
  //   lda    Leading dimension of a.
  //   tau    Scalar factors of the reflectors as returned by dgeqrf.
  //   c      The matrix to transform; overwritten with the product.
  //   ldc    Leading dimension of c.
  //   work   Workspace of size max(1, lwork).
  //   lwork  Size of work when positive; lwork = -1 requests a workspace
  //          query.
  //   info   0 on success; -i if the i-th argument had an illegal value.
  extern "C" void dormqr_(const char* side,
                          const char* trans,
                          const int* m,
                          const int* n,
                          const int* k,
                          double* a,
                          const int* lda,
                          double* tau,
                          double* c,
                          const int* ldc,
                          double* work,
                          const int* lwork,
                          int* info);

  // dtrtrs: solves the triangular system op(a) * x = b, overwriting b with x.
  //
  //   uplo   'U' if a is upper triangular, 'L' if it is lower triangular.
  //   trans  'N', 'T' or 'C' selects op(a) = a, a^T or a^H.
  //   diag   'N' if a is not unit triangular, 'U' if it is.
  //   n      Order of a.
  //   nrhs   Number of columns of b.
  //   a      Column major lda x n triangular matrix.
  //   lda    Leading dimension of a.
  //   b      Column major ldb x nrhs right hand side; overwritten with x.
  //   ldb    Leading dimension of b.
  //   info   0 on success; -i if the i-th argument had an illegal value;
  //          i > 0 if the i-th diagonal element of a is zero (a is singular
  //          and no solution is computed).
  extern "C" void dtrtrs_(const char* uplo,
                          const char* trans,
                          const char* diag,
                          const int* n,
                          const int* nrhs,
                          double* a,
                          const int* lda,
                          double* b,
                          const int* ldb,
                          int* info);

  // clang-format on
  #endif  // CERES_NO_LAPACK
  101. namespace ceres::internal {
  102. DenseQR::~DenseQR() = default;
  103. std::unique_ptr<DenseQR> DenseQR::Create(const LinearSolver::Options& options) {
  104. std::unique_ptr<DenseQR> dense_qr;
  105. switch (options.dense_linear_algebra_library_type) {
  106. case EIGEN:
  107. dense_qr = std::make_unique<EigenDenseQR>();
  108. break;
  109. case LAPACK:
  110. #ifndef CERES_NO_LAPACK
  111. dense_qr = std::make_unique<LAPACKDenseQR>();
  112. break;
  113. #else
  114. LOG(FATAL) << "Ceres was compiled without support for LAPACK.";
  115. #endif
  116. case CUDA:
  117. #ifndef CERES_NO_CUDA
  118. dense_qr = CUDADenseQR::Create(options);
  119. break;
  120. #else
  121. LOG(FATAL) << "Ceres was compiled without support for CUDA.";
  122. #endif
  123. default:
  124. LOG(FATAL) << "Unknown dense linear algebra library type : "
  125. << DenseLinearAlgebraLibraryTypeToString(
  126. options.dense_linear_algebra_library_type);
  127. }
  128. return dense_qr;
  129. }
  130. LinearSolverTerminationType DenseQR::FactorAndSolve(int num_rows,
  131. int num_cols,
  132. double* lhs,
  133. const double* rhs,
  134. double* solution,
  135. std::string* message) {
  136. LinearSolverTerminationType termination_type =
  137. Factorize(num_rows, num_cols, lhs, message);
  138. if (termination_type == LinearSolverTerminationType::SUCCESS) {
  139. termination_type = Solve(rhs, solution, message);
  140. }
  141. return termination_type;
  142. }
  143. LinearSolverTerminationType EigenDenseQR::Factorize(int num_rows,
  144. int num_cols,
  145. double* lhs,
  146. std::string* message) {
  147. Eigen::Map<ColMajorMatrix> m(lhs, num_rows, num_cols);
  148. qr_ = std::make_unique<QRType>(m);
  149. *message = "Success.";
  150. return LinearSolverTerminationType::SUCCESS;
  151. }
  152. LinearSolverTerminationType EigenDenseQR::Solve(const double* rhs,
  153. double* solution,
  154. std::string* message) {
  155. VectorRef(solution, qr_->cols()) =
  156. qr_->solve(ConstVectorRef(rhs, qr_->rows()));
  157. *message = "Success.";
  158. return LinearSolverTerminationType::SUCCESS;
  159. }
  160. #ifndef CERES_NO_LAPACK
  161. LinearSolverTerminationType LAPACKDenseQR::Factorize(int num_rows,
  162. int num_cols,
  163. double* lhs,
  164. std::string* message) {
  165. int lwork = -1;
  166. double work_size;
  167. int info = 0;
  168. // Compute the size of the temporary workspace needed to compute the QR
  169. // factorization in the dgeqrf call below.
  170. dgeqrf_(&num_rows,
  171. &num_cols,
  172. lhs_,
  173. &num_rows,
  174. tau_.data(),
  175. &work_size,
  176. &lwork,
  177. &info);
  178. if (info < 0) {
  179. LOG(FATAL) << "Congratulations, you found a bug in Ceres."
  180. << "Please report it."
  181. << "LAPACK::dgels fatal error."
  182. << "Argument: " << -info << " is invalid.";
  183. }
  184. lhs_ = lhs;
  185. num_rows_ = num_rows;
  186. num_cols_ = num_cols;
  187. lwork = static_cast<int>(work_size);
  188. if (work_.size() < lwork) {
  189. work_.resize(lwork);
  190. }
  191. if (tau_.size() < num_cols) {
  192. tau_.resize(num_cols);
  193. }
  194. if (q_transpose_rhs_.size() < num_rows) {
  195. q_transpose_rhs_.resize(num_rows);
  196. }
  197. // Factorize the lhs_ using the workspace that we just constructed above.
  198. dgeqrf_(&num_rows,
  199. &num_cols,
  200. lhs_,
  201. &num_rows,
  202. tau_.data(),
  203. work_.data(),
  204. &lwork,
  205. &info);
  206. if (info < 0) {
  207. LOG(FATAL) << "Congratulations, you found a bug in Ceres."
  208. << "Please report it. dgeqrf fatal error."
  209. << "Argument: " << -info << " is invalid.";
  210. }
  211. termination_type_ = LinearSolverTerminationType::SUCCESS;
  212. *message = "Success.";
  213. return termination_type_;
  214. }
  215. LinearSolverTerminationType LAPACKDenseQR::Solve(const double* rhs,
  216. double* solution,
  217. std::string* message) {
  218. if (termination_type_ != LinearSolverTerminationType::SUCCESS) {
  219. *message = "QR factorization failed and solve called.";
  220. return termination_type_;
  221. }
  222. std::copy_n(rhs, num_rows_, q_transpose_rhs_.data());
  223. const char side = 'L';
  224. char trans = 'T';
  225. const int num_c_cols = 1;
  226. const int lwork = work_.size();
  227. int info = 0;
  228. dormqr_(&side,
  229. &trans,
  230. &num_rows_,
  231. &num_c_cols,
  232. &num_cols_,
  233. lhs_,
  234. &num_rows_,
  235. tau_.data(),
  236. q_transpose_rhs_.data(),
  237. &num_rows_,
  238. work_.data(),
  239. &lwork,
  240. &info);
  241. if (info < 0) {
  242. LOG(FATAL) << "Congratulations, you found a bug in Ceres."
  243. << "Please report it. dormr fatal error."
  244. << "Argument: " << -info << " is invalid.";
  245. }
  246. const char uplo = 'U';
  247. trans = 'N';
  248. const char diag = 'N';
  249. dtrtrs_(&uplo,
  250. &trans,
  251. &diag,
  252. &num_cols_,
  253. &num_c_cols,
  254. lhs_,
  255. &num_rows_,
  256. q_transpose_rhs_.data(),
  257. &num_rows_,
  258. &info);
  259. if (info < 0) {
  260. LOG(FATAL) << "Congratulations, you found a bug in Ceres."
  261. << "Please report it. dormr fatal error."
  262. << "Argument: " << -info << " is invalid.";
  263. } else if (info > 0) {
  264. *message =
  265. "QR factorization failure. The factorization is not full rank. R has "
  266. "zeros on the diagonal.";
  267. termination_type_ = LinearSolverTerminationType::FAILURE;
  268. } else {
  269. std::copy_n(q_transpose_rhs_.data(), num_cols_, solution);
  270. termination_type_ = LinearSolverTerminationType::SUCCESS;
  271. }
  272. return termination_type_;
  273. }
  274. #endif // CERES_NO_LAPACK
  275. #ifndef CERES_NO_CUDA
  276. CUDADenseQR::CUDADenseQR(ContextImpl* context)
  277. : context_(context),
  278. lhs_{context},
  279. rhs_{context},
  280. tau_{context},
  281. device_workspace_{context},
  282. error_(context, 1) {}
  283. LinearSolverTerminationType CUDADenseQR::Factorize(int num_rows,
  284. int num_cols,
  285. double* lhs,
  286. std::string* message) {
  287. factorize_result_ = LinearSolverTerminationType::FATAL_ERROR;
  288. lhs_.Reserve(num_rows * num_cols);
  289. tau_.Reserve(std::min(num_rows, num_cols));
  290. num_rows_ = num_rows;
  291. num_cols_ = num_cols;
  292. lhs_.CopyFromCpu(lhs, num_rows * num_cols);
  293. int device_workspace_size = 0;
  294. if (cusolverDnDgeqrf_bufferSize(context_->cusolver_handle_,
  295. num_rows,
  296. num_cols,
  297. lhs_.data(),
  298. num_rows,
  299. &device_workspace_size) !=
  300. CUSOLVER_STATUS_SUCCESS) {
  301. *message = "cuSolverDN::cusolverDnDgeqrf_bufferSize failed.";
  302. return LinearSolverTerminationType::FATAL_ERROR;
  303. }
  304. device_workspace_.Reserve(device_workspace_size);
  305. if (cusolverDnDgeqrf(context_->cusolver_handle_,
  306. num_rows,
  307. num_cols,
  308. lhs_.data(),
  309. num_rows,
  310. tau_.data(),
  311. reinterpret_cast<double*>(device_workspace_.data()),
  312. device_workspace_.size(),
  313. error_.data()) != CUSOLVER_STATUS_SUCCESS) {
  314. *message = "cuSolverDN::cusolverDnDgeqrf failed.";
  315. return LinearSolverTerminationType::FATAL_ERROR;
  316. }
  317. int error = 0;
  318. error_.CopyToCpu(&error, 1);
  319. if (error < 0) {
  320. LOG(FATAL) << "Congratulations, you found a bug in Ceres - "
  321. << "please report it. "
  322. << "cuSolverDN::cusolverDnDgeqrf fatal error. "
  323. << "Argument: " << -error << " is invalid.";
  324. // The following line is unreachable, but return failure just to be
  325. // pedantic, since the compiler does not know that.
  326. return LinearSolverTerminationType::FATAL_ERROR;
  327. }
  328. *message = "Success";
  329. factorize_result_ = LinearSolverTerminationType::SUCCESS;
  330. return LinearSolverTerminationType::SUCCESS;
  331. }
  332. LinearSolverTerminationType CUDADenseQR::Solve(const double* rhs,
  333. double* solution,
  334. std::string* message) {
  335. if (factorize_result_ != LinearSolverTerminationType::SUCCESS) {
  336. *message = "Factorize did not complete successfully previously.";
  337. return factorize_result_;
  338. }
  339. rhs_.CopyFromCpu(rhs, num_rows_);
  340. int device_workspace_size = 0;
  341. if (cusolverDnDormqr_bufferSize(context_->cusolver_handle_,
  342. CUBLAS_SIDE_LEFT,
  343. CUBLAS_OP_T,
  344. num_rows_,
  345. 1,
  346. num_cols_,
  347. lhs_.data(),
  348. num_rows_,
  349. tau_.data(),
  350. rhs_.data(),
  351. num_rows_,
  352. &device_workspace_size) !=
  353. CUSOLVER_STATUS_SUCCESS) {
  354. *message = "cuSolverDN::cusolverDnDormqr_bufferSize failed.";
  355. return LinearSolverTerminationType::FATAL_ERROR;
  356. }
  357. device_workspace_.Reserve(device_workspace_size);
  358. // Compute rhs = Q^T * rhs, assuming that lhs has already been factorized.
  359. // The result of factorization would have stored Q in a packed form in lhs_.
  360. if (cusolverDnDormqr(context_->cusolver_handle_,
  361. CUBLAS_SIDE_LEFT,
  362. CUBLAS_OP_T,
  363. num_rows_,
  364. 1,
  365. num_cols_,
  366. lhs_.data(),
  367. num_rows_,
  368. tau_.data(),
  369. rhs_.data(),
  370. num_rows_,
  371. reinterpret_cast<double*>(device_workspace_.data()),
  372. device_workspace_.size(),
  373. error_.data()) != CUSOLVER_STATUS_SUCCESS) {
  374. *message = "cuSolverDN::cusolverDnDormqr failed.";
  375. return LinearSolverTerminationType::FATAL_ERROR;
  376. }
  377. int error = 0;
  378. error_.CopyToCpu(&error, 1);
  379. if (error < 0) {
  380. LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
  381. << "Please report it."
  382. << "cuSolverDN::cusolverDnDormqr fatal error. "
  383. << "Argument: " << -error << " is invalid.";
  384. }
  385. // Compute the solution vector as x = R \ (Q^T * rhs). Since the previous step
  386. // replaced rhs by (Q^T * rhs), this is just x = R \ rhs.
  387. if (cublasDtrsv(context_->cublas_handle_,
  388. CUBLAS_FILL_MODE_UPPER,
  389. CUBLAS_OP_N,
  390. CUBLAS_DIAG_NON_UNIT,
  391. num_cols_,
  392. lhs_.data(),
  393. num_rows_,
  394. rhs_.data(),
  395. 1) != CUBLAS_STATUS_SUCCESS) {
  396. *message = "cuBLAS::cublasDtrsv failed.";
  397. return LinearSolverTerminationType::FATAL_ERROR;
  398. }
  399. rhs_.CopyToCpu(solution, num_cols_);
  400. *message = "Success";
  401. return LinearSolverTerminationType::SUCCESS;
  402. }
  403. std::unique_ptr<CUDADenseQR> CUDADenseQR::Create(
  404. const LinearSolver::Options& options) {
  405. if (options.dense_linear_algebra_library_type != CUDA ||
  406. options.context == nullptr || !options.context->IsCudaInitialized()) {
  407. return nullptr;
  408. }
  409. return std::unique_ptr<CUDADenseQR>(new CUDADenseQR(options.context));
  410. }
  411. #endif // CERES_NO_CUDA
  412. } // namespace ceres::internal