dense_cholesky.cc 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2022 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Author: sameeragarwal@google.com (Sameer Agarwal)
  30. #include "ceres/dense_cholesky.h"
  31. #include <algorithm>
  32. #include <memory>
  33. #include <string>
  34. #include <utility>
  35. #include <vector>
  36. #include "ceres/internal/config.h"
  37. #include "ceres/iterative_refiner.h"
  38. #ifndef CERES_NO_CUDA
  39. #include "ceres/context_impl.h"
  40. #include "ceres/cuda_kernels_vector_ops.h"
  41. #include "cuda_runtime.h"
  42. #include "cusolverDn.h"
  43. #endif // CERES_NO_CUDA
  44. #ifndef CERES_NO_LAPACK
  45. // C interface to the LAPACK Cholesky factorization and triangular solve.
  46. extern "C" void dpotrf_(
  47. const char* uplo, const int* n, double* a, const int* lda, int* info);
  48. extern "C" void dpotrs_(const char* uplo,
  49. const int* n,
  50. const int* nrhs,
  51. const double* a,
  52. const int* lda,
  53. double* b,
  54. const int* ldb,
  55. int* info);
  56. extern "C" void spotrf_(
  57. const char* uplo, const int* n, float* a, const int* lda, int* info);
  58. extern "C" void spotrs_(const char* uplo,
  59. const int* n,
  60. const int* nrhs,
  61. const float* a,
  62. const int* lda,
  63. float* b,
  64. const int* ldb,
  65. int* info);
  66. #endif
  67. namespace ceres::internal {
  68. DenseCholesky::~DenseCholesky() = default;
  69. std::unique_ptr<DenseCholesky> DenseCholesky::Create(
  70. const LinearSolver::Options& options) {
  71. std::unique_ptr<DenseCholesky> dense_cholesky;
  72. switch (options.dense_linear_algebra_library_type) {
  73. case EIGEN:
  74. // Eigen mixed precision solver not yet implemented.
  75. if (options.use_mixed_precision_solves) {
  76. dense_cholesky = std::make_unique<FloatEigenDenseCholesky>();
  77. } else {
  78. dense_cholesky = std::make_unique<EigenDenseCholesky>();
  79. }
  80. break;
  81. case LAPACK:
  82. #ifndef CERES_NO_LAPACK
  83. // LAPACK mixed precision solver not yet implemented.
  84. if (options.use_mixed_precision_solves) {
  85. dense_cholesky = std::make_unique<FloatLAPACKDenseCholesky>();
  86. } else {
  87. dense_cholesky = std::make_unique<LAPACKDenseCholesky>();
  88. }
  89. break;
  90. #else
  91. LOG(FATAL) << "Ceres was compiled without support for LAPACK.";
  92. #endif
  93. case CUDA:
  94. #ifndef CERES_NO_CUDA
  95. if (options.use_mixed_precision_solves) {
  96. dense_cholesky = CUDADenseCholeskyMixedPrecision::Create(options);
  97. } else {
  98. dense_cholesky = CUDADenseCholesky::Create(options);
  99. }
  100. break;
  101. #else
  102. LOG(FATAL) << "Ceres was compiled without support for CUDA.";
  103. #endif
  104. default:
  105. LOG(FATAL) << "Unknown dense linear algebra library type : "
  106. << DenseLinearAlgebraLibraryTypeToString(
  107. options.dense_linear_algebra_library_type);
  108. }
  109. if (options.max_num_refinement_iterations > 0) {
  110. auto refiner = std::make_unique<DenseIterativeRefiner>(
  111. options.max_num_refinement_iterations);
  112. dense_cholesky = std::make_unique<RefinedDenseCholesky>(
  113. std::move(dense_cholesky), std::move(refiner));
  114. }
  115. return dense_cholesky;
  116. }
  117. LinearSolverTerminationType DenseCholesky::FactorAndSolve(
  118. int num_cols,
  119. double* lhs,
  120. const double* rhs,
  121. double* solution,
  122. std::string* message) {
  123. LinearSolverTerminationType termination_type =
  124. Factorize(num_cols, lhs, message);
  125. if (termination_type == LinearSolverTerminationType::SUCCESS) {
  126. termination_type = Solve(rhs, solution, message);
  127. }
  128. return termination_type;
  129. }
  130. LinearSolverTerminationType EigenDenseCholesky::Factorize(
  131. int num_cols, double* lhs, std::string* message) {
  132. Eigen::Map<Eigen::MatrixXd> m(lhs, num_cols, num_cols);
  133. llt_ = std::make_unique<LLTType>(m);
  134. if (llt_->info() != Eigen::Success) {
  135. *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
  136. return LinearSolverTerminationType::FAILURE;
  137. }
  138. *message = "Success.";
  139. return LinearSolverTerminationType::SUCCESS;
  140. }
  141. LinearSolverTerminationType EigenDenseCholesky::Solve(const double* rhs,
  142. double* solution,
  143. std::string* message) {
  144. if (llt_->info() != Eigen::Success) {
  145. *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
  146. return LinearSolverTerminationType::FAILURE;
  147. }
  148. VectorRef(solution, llt_->cols()) =
  149. llt_->solve(ConstVectorRef(rhs, llt_->cols()));
  150. *message = "Success.";
  151. return LinearSolverTerminationType::SUCCESS;
  152. }
  153. LinearSolverTerminationType FloatEigenDenseCholesky::Factorize(
  154. int num_cols, double* lhs, std::string* message) {
  155. // TODO(sameeragarwal): Check if this causes a double allocation.
  156. lhs_ = Eigen::Map<Eigen::MatrixXd>(lhs, num_cols, num_cols).cast<float>();
  157. llt_ = std::make_unique<LLTType>(lhs_);
  158. if (llt_->info() != Eigen::Success) {
  159. *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
  160. return LinearSolverTerminationType::FAILURE;
  161. }
  162. *message = "Success.";
  163. return LinearSolverTerminationType::SUCCESS;
  164. }
  165. LinearSolverTerminationType FloatEigenDenseCholesky::Solve(
  166. const double* rhs, double* solution, std::string* message) {
  167. if (llt_->info() != Eigen::Success) {
  168. *message = "Eigen failure. Unable to perform dense Cholesky factorization.";
  169. return LinearSolverTerminationType::FAILURE;
  170. }
  171. rhs_ = ConstVectorRef(rhs, llt_->cols()).cast<float>();
  172. solution_ = llt_->solve(rhs_);
  173. VectorRef(solution, llt_->cols()) = solution_.cast<double>();
  174. *message = "Success.";
  175. return LinearSolverTerminationType::SUCCESS;
  176. }
  177. #ifndef CERES_NO_LAPACK
  178. LinearSolverTerminationType LAPACKDenseCholesky::Factorize(
  179. int num_cols, double* lhs, std::string* message) {
  180. lhs_ = lhs;
  181. num_cols_ = num_cols;
  182. const char uplo = 'L';
  183. int info = 0;
  184. dpotrf_(&uplo, &num_cols_, lhs_, &num_cols_, &info);
  185. if (info < 0) {
  186. termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
  187. LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
  188. << "Please report it. "
  189. << "LAPACK::dpotrf fatal error. "
  190. << "Argument: " << -info << " is invalid.";
  191. } else if (info > 0) {
  192. termination_type_ = LinearSolverTerminationType::FAILURE;
  193. *message = StringPrintf(
  194. "LAPACK::dpotrf numerical failure. "
  195. "The leading minor of order %d is not positive definite.",
  196. info);
  197. } else {
  198. termination_type_ = LinearSolverTerminationType::SUCCESS;
  199. *message = "Success.";
  200. }
  201. return termination_type_;
  202. }
  203. LinearSolverTerminationType LAPACKDenseCholesky::Solve(const double* rhs,
  204. double* solution,
  205. std::string* message) {
  206. const char uplo = 'L';
  207. const int nrhs = 1;
  208. int info = 0;
  209. VectorRef(solution, num_cols_) = ConstVectorRef(rhs, num_cols_);
  210. dpotrs_(
  211. &uplo, &num_cols_, &nrhs, lhs_, &num_cols_, solution, &num_cols_, &info);
  212. if (info < 0) {
  213. termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
  214. LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
  215. << "Please report it. "
  216. << "LAPACK::dpotrs fatal error. "
  217. << "Argument: " << -info << " is invalid.";
  218. }
  219. *message = "Success";
  220. termination_type_ = LinearSolverTerminationType::SUCCESS;
  221. return termination_type_;
  222. }
  223. LinearSolverTerminationType FloatLAPACKDenseCholesky::Factorize(
  224. int num_cols, double* lhs, std::string* message) {
  225. num_cols_ = num_cols;
  226. lhs_ = Eigen::Map<Eigen::MatrixXd>(lhs, num_cols, num_cols).cast<float>();
  227. const char uplo = 'L';
  228. int info = 0;
  229. spotrf_(&uplo, &num_cols_, lhs_.data(), &num_cols_, &info);
  230. if (info < 0) {
  231. termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
  232. LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
  233. << "Please report it. "
  234. << "LAPACK::spotrf fatal error. "
  235. << "Argument: " << -info << " is invalid.";
  236. } else if (info > 0) {
  237. termination_type_ = LinearSolverTerminationType::FAILURE;
  238. *message = StringPrintf(
  239. "LAPACK::spotrf numerical failure. "
  240. "The leading minor of order %d is not positive definite.",
  241. info);
  242. } else {
  243. termination_type_ = LinearSolverTerminationType::SUCCESS;
  244. *message = "Success.";
  245. }
  246. return termination_type_;
  247. }
  248. LinearSolverTerminationType FloatLAPACKDenseCholesky::Solve(
  249. const double* rhs, double* solution, std::string* message) {
  250. const char uplo = 'L';
  251. const int nrhs = 1;
  252. int info = 0;
  253. rhs_and_solution_ = ConstVectorRef(rhs, num_cols_).cast<float>();
  254. spotrs_(&uplo,
  255. &num_cols_,
  256. &nrhs,
  257. lhs_.data(),
  258. &num_cols_,
  259. rhs_and_solution_.data(),
  260. &num_cols_,
  261. &info);
  262. if (info < 0) {
  263. termination_type_ = LinearSolverTerminationType::FATAL_ERROR;
  264. LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
  265. << "Please report it. "
  266. << "LAPACK::dpotrs fatal error. "
  267. << "Argument: " << -info << " is invalid.";
  268. }
  269. *message = "Success";
  270. termination_type_ = LinearSolverTerminationType::SUCCESS;
  271. VectorRef(solution, num_cols_) =
  272. rhs_and_solution_.head(num_cols_).cast<double>();
  273. return termination_type_;
  274. }
  275. #endif // CERES_NO_LAPACK
  276. RefinedDenseCholesky::RefinedDenseCholesky(
  277. std::unique_ptr<DenseCholesky> dense_cholesky,
  278. std::unique_ptr<DenseIterativeRefiner> iterative_refiner)
  279. : dense_cholesky_(std::move(dense_cholesky)),
  280. iterative_refiner_(std::move(iterative_refiner)) {}
  281. RefinedDenseCholesky::~RefinedDenseCholesky() = default;
  282. LinearSolverTerminationType RefinedDenseCholesky::Factorize(
  283. const int num_cols, double* lhs, std::string* message) {
  284. lhs_ = lhs;
  285. num_cols_ = num_cols;
  286. return dense_cholesky_->Factorize(num_cols, lhs, message);
  287. }
  288. LinearSolverTerminationType RefinedDenseCholesky::Solve(const double* rhs,
  289. double* solution,
  290. std::string* message) {
  291. CHECK(lhs_ != nullptr);
  292. auto termination_type = dense_cholesky_->Solve(rhs, solution, message);
  293. if (termination_type != LinearSolverTerminationType::SUCCESS) {
  294. return termination_type;
  295. }
  296. iterative_refiner_->Refine(
  297. num_cols_, lhs_, rhs, dense_cholesky_.get(), solution);
  298. return LinearSolverTerminationType::SUCCESS;
  299. }
  300. #ifndef CERES_NO_CUDA
  301. CUDADenseCholesky::CUDADenseCholesky(ContextImpl* context)
  302. : context_(context),
  303. lhs_{context},
  304. rhs_{context},
  305. device_workspace_{context},
  306. error_(context, 1) {}
  307. LinearSolverTerminationType CUDADenseCholesky::Factorize(int num_cols,
  308. double* lhs,
  309. std::string* message) {
  310. factorize_result_ = LinearSolverTerminationType::FATAL_ERROR;
  311. lhs_.Reserve(num_cols * num_cols);
  312. num_cols_ = num_cols;
  313. lhs_.CopyFromCpu(lhs, num_cols * num_cols);
  314. int device_workspace_size = 0;
  315. if (cusolverDnDpotrf_bufferSize(context_->cusolver_handle_,
  316. CUBLAS_FILL_MODE_LOWER,
  317. num_cols,
  318. lhs_.data(),
  319. num_cols,
  320. &device_workspace_size) !=
  321. CUSOLVER_STATUS_SUCCESS) {
  322. *message = "cuSolverDN::cusolverDnDpotrf_bufferSize failed.";
  323. return LinearSolverTerminationType::FATAL_ERROR;
  324. }
  325. device_workspace_.Reserve(device_workspace_size);
  326. if (cusolverDnDpotrf(context_->cusolver_handle_,
  327. CUBLAS_FILL_MODE_LOWER,
  328. num_cols,
  329. lhs_.data(),
  330. num_cols,
  331. reinterpret_cast<double*>(device_workspace_.data()),
  332. device_workspace_.size(),
  333. error_.data()) != CUSOLVER_STATUS_SUCCESS) {
  334. *message = "cuSolverDN::cusolverDnDpotrf failed.";
  335. return LinearSolverTerminationType::FATAL_ERROR;
  336. }
  337. int error = 0;
  338. error_.CopyToCpu(&error, 1);
  339. if (error < 0) {
  340. LOG(FATAL) << "Congratulations, you found a bug in Ceres - "
  341. << "please report it. "
  342. << "cuSolverDN::cusolverDnXpotrf fatal error. "
  343. << "Argument: " << -error << " is invalid.";
  344. // The following line is unreachable, but return failure just to be
  345. // pedantic, since the compiler does not know that.
  346. return LinearSolverTerminationType::FATAL_ERROR;
  347. } else if (error > 0) {
  348. *message = StringPrintf(
  349. "cuSolverDN::cusolverDnDpotrf numerical failure. "
  350. "The leading minor of order %d is not positive definite.",
  351. error);
  352. factorize_result_ = LinearSolverTerminationType::FAILURE;
  353. return LinearSolverTerminationType::FAILURE;
  354. }
  355. *message = "Success";
  356. factorize_result_ = LinearSolverTerminationType::SUCCESS;
  357. return LinearSolverTerminationType::SUCCESS;
  358. }
  359. LinearSolverTerminationType CUDADenseCholesky::Solve(const double* rhs,
  360. double* solution,
  361. std::string* message) {
  362. if (factorize_result_ != LinearSolverTerminationType::SUCCESS) {
  363. *message = "Factorize did not complete successfully previously.";
  364. return factorize_result_;
  365. }
  366. rhs_.CopyFromCpu(rhs, num_cols_);
  367. if (cusolverDnDpotrs(context_->cusolver_handle_,
  368. CUBLAS_FILL_MODE_LOWER,
  369. num_cols_,
  370. 1,
  371. lhs_.data(),
  372. num_cols_,
  373. rhs_.data(),
  374. num_cols_,
  375. error_.data()) != CUSOLVER_STATUS_SUCCESS) {
  376. *message = "cuSolverDN::cusolverDnDpotrs failed.";
  377. return LinearSolverTerminationType::FATAL_ERROR;
  378. }
  379. int error = 0;
  380. error_.CopyToCpu(&error, 1);
  381. if (error != 0) {
  382. LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
  383. << "Please report it."
  384. << "cuSolverDN::cusolverDnDpotrs fatal error. "
  385. << "Argument: " << -error << " is invalid.";
  386. }
  387. rhs_.CopyToCpu(solution, num_cols_);
  388. *message = "Success";
  389. return LinearSolverTerminationType::SUCCESS;
  390. }
  391. std::unique_ptr<CUDADenseCholesky> CUDADenseCholesky::Create(
  392. const LinearSolver::Options& options) {
  393. if (options.dense_linear_algebra_library_type != CUDA ||
  394. options.context == nullptr || !options.context->IsCudaInitialized()) {
  395. return nullptr;
  396. }
  397. return std::unique_ptr<CUDADenseCholesky>(
  398. new CUDADenseCholesky(options.context));
  399. }
  400. std::unique_ptr<CUDADenseCholeskyMixedPrecision>
  401. CUDADenseCholeskyMixedPrecision::Create(const LinearSolver::Options& options) {
  402. if (options.dense_linear_algebra_library_type != CUDA ||
  403. !options.use_mixed_precision_solves || options.context == nullptr ||
  404. !options.context->IsCudaInitialized()) {
  405. return nullptr;
  406. }
  407. return std::unique_ptr<CUDADenseCholeskyMixedPrecision>(
  408. new CUDADenseCholeskyMixedPrecision(
  409. options.context, options.max_num_refinement_iterations));
  410. }
  411. LinearSolverTerminationType
  412. CUDADenseCholeskyMixedPrecision::CudaCholeskyFactorize(std::string* message) {
  413. int device_workspace_size = 0;
  414. if (cusolverDnSpotrf_bufferSize(context_->cusolver_handle_,
  415. CUBLAS_FILL_MODE_LOWER,
  416. num_cols_,
  417. lhs_fp32_.data(),
  418. num_cols_,
  419. &device_workspace_size) !=
  420. CUSOLVER_STATUS_SUCCESS) {
  421. *message = "cuSolverDN::cusolverDnSpotrf_bufferSize failed.";
  422. return LinearSolverTerminationType::FATAL_ERROR;
  423. }
  424. device_workspace_.Reserve(device_workspace_size);
  425. if (cusolverDnSpotrf(context_->cusolver_handle_,
  426. CUBLAS_FILL_MODE_LOWER,
  427. num_cols_,
  428. lhs_fp32_.data(),
  429. num_cols_,
  430. device_workspace_.data(),
  431. device_workspace_.size(),
  432. error_.data()) != CUSOLVER_STATUS_SUCCESS) {
  433. *message = "cuSolverDN::cusolverDnSpotrf failed.";
  434. return LinearSolverTerminationType::FATAL_ERROR;
  435. }
  436. int error = 0;
  437. error_.CopyToCpu(&error, 1);
  438. if (error < 0) {
  439. LOG(FATAL) << "Congratulations, you found a bug in Ceres - "
  440. << "please report it. "
  441. << "cuSolverDN::cusolverDnSpotrf fatal error. "
  442. << "Argument: " << -error << " is invalid.";
  443. // The following line is unreachable, but return failure just to be
  444. // pedantic, since the compiler does not know that.
  445. return LinearSolverTerminationType::FATAL_ERROR;
  446. }
  447. if (error > 0) {
  448. *message = StringPrintf(
  449. "cuSolverDN::cusolverDnSpotrf numerical failure. "
  450. "The leading minor of order %d is not positive definite.",
  451. error);
  452. factorize_result_ = LinearSolverTerminationType::FAILURE;
  453. return LinearSolverTerminationType::FAILURE;
  454. }
  455. *message = "Success";
  456. return LinearSolverTerminationType::SUCCESS;
  457. }
  458. LinearSolverTerminationType CUDADenseCholeskyMixedPrecision::CudaCholeskySolve(
  459. std::string* message) {
  460. CHECK_EQ(cudaMemcpyAsync(correction_fp32_.data(),
  461. residual_fp32_.data(),
  462. num_cols_ * sizeof(float),
  463. cudaMemcpyDeviceToDevice,
  464. context_->DefaultStream()),
  465. cudaSuccess);
  466. if (cusolverDnSpotrs(context_->cusolver_handle_,
  467. CUBLAS_FILL_MODE_LOWER,
  468. num_cols_,
  469. 1,
  470. lhs_fp32_.data(),
  471. num_cols_,
  472. correction_fp32_.data(),
  473. num_cols_,
  474. error_.data()) != CUSOLVER_STATUS_SUCCESS) {
  475. *message = "cuSolverDN::cusolverDnDpotrs failed.";
  476. return LinearSolverTerminationType::FATAL_ERROR;
  477. }
  478. int error = 0;
  479. error_.CopyToCpu(&error, 1);
  480. if (error != 0) {
  481. LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
  482. << "Please report it."
  483. << "cuSolverDN::cusolverDnDpotrs fatal error. "
  484. << "Argument: " << -error << " is invalid.";
  485. }
  486. *message = "Success";
  487. return LinearSolverTerminationType::SUCCESS;
  488. }
  489. CUDADenseCholeskyMixedPrecision::CUDADenseCholeskyMixedPrecision(
  490. ContextImpl* context, int max_num_refinement_iterations)
  491. : context_(context),
  492. lhs_fp64_{context},
  493. rhs_fp64_{context},
  494. lhs_fp32_{context},
  495. device_workspace_{context},
  496. error_(context, 1),
  497. x_fp64_{context},
  498. correction_fp32_{context},
  499. residual_fp32_{context},
  500. residual_fp64_{context},
  501. max_num_refinement_iterations_(max_num_refinement_iterations) {}
  502. LinearSolverTerminationType CUDADenseCholeskyMixedPrecision::Factorize(
  503. int num_cols, double* lhs, std::string* message) {
  504. num_cols_ = num_cols;
  505. // Copy fp64 version of lhs to GPU.
  506. lhs_fp64_.Reserve(num_cols * num_cols);
  507. lhs_fp64_.CopyFromCpu(lhs, num_cols * num_cols);
  508. // Create an fp32 copy of lhs, lhs_fp32.
  509. lhs_fp32_.Reserve(num_cols * num_cols);
  510. CudaFP64ToFP32(lhs_fp64_.data(),
  511. lhs_fp32_.data(),
  512. num_cols * num_cols,
  513. context_->DefaultStream());
  514. // Factorize lhs_fp32.
  515. factorize_result_ = CudaCholeskyFactorize(message);
  516. return factorize_result_;
  517. }
  518. LinearSolverTerminationType CUDADenseCholeskyMixedPrecision::Solve(
  519. const double* rhs, double* solution, std::string* message) {
  520. // If factorization failed, return failure.
  521. if (factorize_result_ != LinearSolverTerminationType::SUCCESS) {
  522. *message = "Factorize did not complete successfully previously.";
  523. return factorize_result_;
  524. }
  525. // Reserve memory for all arrays.
  526. rhs_fp64_.Reserve(num_cols_);
  527. x_fp64_.Reserve(num_cols_);
  528. correction_fp32_.Reserve(num_cols_);
  529. residual_fp32_.Reserve(num_cols_);
  530. residual_fp64_.Reserve(num_cols_);
  531. // Initialize x = 0.
  532. CudaSetZeroFP64(x_fp64_.data(), num_cols_, context_->DefaultStream());
  533. // Initialize residual = rhs.
  534. rhs_fp64_.CopyFromCpu(rhs, num_cols_);
  535. residual_fp64_.CopyFromGPUArray(rhs_fp64_.data(), num_cols_);
  536. for (int i = 0; i <= max_num_refinement_iterations_; ++i) {
  537. // Cast residual from fp64 to fp32.
  538. CudaFP64ToFP32(residual_fp64_.data(),
  539. residual_fp32_.data(),
  540. num_cols_,
  541. context_->DefaultStream());
  542. // [fp32] c = lhs^-1 * residual.
  543. auto result = CudaCholeskySolve(message);
  544. if (result != LinearSolverTerminationType::SUCCESS) {
  545. return result;
  546. }
  547. // [fp64] x += c.
  548. CudaDsxpy(x_fp64_.data(),
  549. correction_fp32_.data(),
  550. num_cols_,
  551. context_->DefaultStream());
  552. if (i < max_num_refinement_iterations_) {
  553. // [fp64] residual = rhs - lhs * x
  554. // This is done in two steps:
  555. // 1. [fp64] residual = rhs
  556. residual_fp64_.CopyFromGPUArray(rhs_fp64_.data(), num_cols_);
  557. // 2. [fp64] residual = residual - lhs * x
  558. double alpha = -1.0;
  559. double beta = 1.0;
  560. cublasDsymv(context_->cublas_handle_,
  561. CUBLAS_FILL_MODE_LOWER,
  562. num_cols_,
  563. &alpha,
  564. lhs_fp64_.data(),
  565. num_cols_,
  566. x_fp64_.data(),
  567. 1,
  568. &beta,
  569. residual_fp64_.data(),
  570. 1);
  571. }
  572. }
  573. x_fp64_.CopyToCpu(solution, num_cols_);
  574. *message = "Success.";
  575. return LinearSolverTerminationType::SUCCESS;
  576. }
  577. #endif // CERES_NO_CUDA
  578. } // namespace ceres::internal