123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456 |
- // Ceres Solver - A fast non-linear least squares minimizer
- // Copyright 2022 Google Inc. All rights reserved.
- // http://ceres-solver.org/
- //
- // Redistribution and use in source and binary forms, with or without
- // modification, are permitted provided that the following conditions are met:
- //
- // * Redistributions of source code must retain the above copyright notice,
- // this list of conditions and the following disclaimer.
- // * Redistributions in binary form must reproduce the above copyright notice,
- // this list of conditions and the following disclaimer in the documentation
- // and/or other materials provided with the distribution.
- // * Neither the name of Google Inc. nor the names of its contributors may be
- // used to endorse or promote products derived from this software without
- // specific prior written permission.
- //
- // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
- // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
- // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
- // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
- // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
- // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
- // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
- // POSSIBILITY OF SUCH DAMAGE.
- //
- // Author: sameeragarwal@google.com (Sameer Agarwal)
- #include "ceres/dense_qr.h"
- #include <algorithm>
- #include <memory>
- #include <string>
- #ifndef CERES_NO_CUDA
- #include "ceres/context_impl.h"
- #include "cublas_v2.h"
- #include "cusolverDn.h"
- #endif // CERES_NO_CUDA
#ifndef CERES_NO_LAPACK
// LAPACK routines for solving a linear least squares problem using QR
// factorization. This is done in three stages:
//
// A * x = b
// Q * R * x = b               (dgeqrf)
// R * x = Q' * b              (dormqr)
// x = R^{-1} * Q' * b         (dtrtrs)

// clang-format off

// Compute the QR factorization of a.
//
// a is an m x n column major matrix (denoted by "A" in the above description).
// lda is the leading dimension of a. lda >= max(1, m).
// tau is an array of size min(m, n). It contains the scalar factors of the
// elementary reflectors.
// work is an array of size max(1, lwork). On exit, if info = 0, work[0] contains
// the optimal size of work.
//
// If lwork >= 1 it is the size of work. If lwork = -1, then a workspace query is
// assumed: dgeqrf computes the optimal size of the work array and returns it as work[0].
//
// info = 0, successful exit.
// info < 0, if info = -i, then the i^th argument had an illegal value.
extern "C" void dgeqrf_(const int* m, const int* n, double* a, const int* lda,
                        double* tau, double* work, const int* lwork, int* info);

// Apply Q or Q' to b.
//
// b is an m x n column major matrix; it is overwritten with the product.
// side = 'L' applies Q or Q' on the left, side = 'R' applies Q or Q' on the right.
// trans = 'N' applies Q, trans = 'T' applies Q'.
// k is the number of elementary reflectors whose product defines the matrix Q.
// If side = 'L', m >= k >= 0 and if side = 'R', n >= k >= 0.
// a is an lda x k column major matrix containing the reflectors as returned by dgeqrf.
// tau contains the scalar factors of the reflectors as returned by dgeqrf.
// ldb is the leading dimension of b.
// work is an array of size max(1, lwork).
// lwork, if positive, is the size of work. If lwork = -1, then a
// workspace query is assumed.
//
// info = 0, successful exit.
// info < 0, if info = -i, then the i^th argument had an illegal value.
extern "C" void dormqr_(const char* side, const char* trans, const int* m,
                        const int* n, const int* k, double* a, const int* lda,
                        double* tau, double* b, const int* ldb, double* work,
                        const int* lwork, int* info);

// Solve a triangular system of the form A * x = b.
//
// uplo = 'U', A is upper triangular. uplo = 'L', A is lower triangular.
// trans = 'N', 'T', 'C' selects the form of the system: A, A^T or A^H.
// diag = 'N', A is not unit triangular. diag = 'U', A is unit triangular.
// n is the order of the matrix A.
// nrhs is the number of columns of b.
// a is a column major lda x n matrix.
// b is a column major ldb x nrhs matrix; on successful exit it holds x.
//
// info = 0, successful exit.
//      = -i < 0, the i^th argument had an illegal value.
//      = i > 0, the i^th diagonal element of A is zero.
extern "C" void dtrtrs_(const char* uplo, const char* trans, const char* diag,
                        const int* n, const int* nrhs, double* a, const int* lda,
                        double* b, const int* ldb, int* info);

// clang-format on
#endif
- namespace ceres::internal {
- DenseQR::~DenseQR() = default;
- std::unique_ptr<DenseQR> DenseQR::Create(const LinearSolver::Options& options) {
- std::unique_ptr<DenseQR> dense_qr;
- switch (options.dense_linear_algebra_library_type) {
- case EIGEN:
- dense_qr = std::make_unique<EigenDenseQR>();
- break;
- case LAPACK:
- #ifndef CERES_NO_LAPACK
- dense_qr = std::make_unique<LAPACKDenseQR>();
- break;
- #else
- LOG(FATAL) << "Ceres was compiled without support for LAPACK.";
- #endif
- case CUDA:
- #ifndef CERES_NO_CUDA
- dense_qr = CUDADenseQR::Create(options);
- break;
- #else
- LOG(FATAL) << "Ceres was compiled without support for CUDA.";
- #endif
- default:
- LOG(FATAL) << "Unknown dense linear algebra library type : "
- << DenseLinearAlgebraLibraryTypeToString(
- options.dense_linear_algebra_library_type);
- }
- return dense_qr;
- }
- LinearSolverTerminationType DenseQR::FactorAndSolve(int num_rows,
- int num_cols,
- double* lhs,
- const double* rhs,
- double* solution,
- std::string* message) {
- LinearSolverTerminationType termination_type =
- Factorize(num_rows, num_cols, lhs, message);
- if (termination_type == LinearSolverTerminationType::SUCCESS) {
- termination_type = Solve(rhs, solution, message);
- }
- return termination_type;
- }
- LinearSolverTerminationType EigenDenseQR::Factorize(int num_rows,
- int num_cols,
- double* lhs,
- std::string* message) {
- Eigen::Map<ColMajorMatrix> m(lhs, num_rows, num_cols);
- qr_ = std::make_unique<QRType>(m);
- *message = "Success.";
- return LinearSolverTerminationType::SUCCESS;
- }
- LinearSolverTerminationType EigenDenseQR::Solve(const double* rhs,
- double* solution,
- std::string* message) {
- VectorRef(solution, qr_->cols()) =
- qr_->solve(ConstVectorRef(rhs, qr_->rows()));
- *message = "Success.";
- return LinearSolverTerminationType::SUCCESS;
- }
- #ifndef CERES_NO_LAPACK
- LinearSolverTerminationType LAPACKDenseQR::Factorize(int num_rows,
- int num_cols,
- double* lhs,
- std::string* message) {
- int lwork = -1;
- double work_size;
- int info = 0;
- // Compute the size of the temporary workspace needed to compute the QR
- // factorization in the dgeqrf call below.
- dgeqrf_(&num_rows,
- &num_cols,
- lhs_,
- &num_rows,
- tau_.data(),
- &work_size,
- &lwork,
- &info);
- if (info < 0) {
- LOG(FATAL) << "Congratulations, you found a bug in Ceres."
- << "Please report it."
- << "LAPACK::dgels fatal error."
- << "Argument: " << -info << " is invalid.";
- }
- lhs_ = lhs;
- num_rows_ = num_rows;
- num_cols_ = num_cols;
- lwork = static_cast<int>(work_size);
- if (work_.size() < lwork) {
- work_.resize(lwork);
- }
- if (tau_.size() < num_cols) {
- tau_.resize(num_cols);
- }
- if (q_transpose_rhs_.size() < num_rows) {
- q_transpose_rhs_.resize(num_rows);
- }
- // Factorize the lhs_ using the workspace that we just constructed above.
- dgeqrf_(&num_rows,
- &num_cols,
- lhs_,
- &num_rows,
- tau_.data(),
- work_.data(),
- &lwork,
- &info);
- if (info < 0) {
- LOG(FATAL) << "Congratulations, you found a bug in Ceres."
- << "Please report it. dgeqrf fatal error."
- << "Argument: " << -info << " is invalid.";
- }
- termination_type_ = LinearSolverTerminationType::SUCCESS;
- *message = "Success.";
- return termination_type_;
- }
- LinearSolverTerminationType LAPACKDenseQR::Solve(const double* rhs,
- double* solution,
- std::string* message) {
- if (termination_type_ != LinearSolverTerminationType::SUCCESS) {
- *message = "QR factorization failed and solve called.";
- return termination_type_;
- }
- std::copy_n(rhs, num_rows_, q_transpose_rhs_.data());
- const char side = 'L';
- char trans = 'T';
- const int num_c_cols = 1;
- const int lwork = work_.size();
- int info = 0;
- dormqr_(&side,
- &trans,
- &num_rows_,
- &num_c_cols,
- &num_cols_,
- lhs_,
- &num_rows_,
- tau_.data(),
- q_transpose_rhs_.data(),
- &num_rows_,
- work_.data(),
- &lwork,
- &info);
- if (info < 0) {
- LOG(FATAL) << "Congratulations, you found a bug in Ceres."
- << "Please report it. dormr fatal error."
- << "Argument: " << -info << " is invalid.";
- }
- const char uplo = 'U';
- trans = 'N';
- const char diag = 'N';
- dtrtrs_(&uplo,
- &trans,
- &diag,
- &num_cols_,
- &num_c_cols,
- lhs_,
- &num_rows_,
- q_transpose_rhs_.data(),
- &num_rows_,
- &info);
- if (info < 0) {
- LOG(FATAL) << "Congratulations, you found a bug in Ceres."
- << "Please report it. dormr fatal error."
- << "Argument: " << -info << " is invalid.";
- } else if (info > 0) {
- *message =
- "QR factorization failure. The factorization is not full rank. R has "
- "zeros on the diagonal.";
- termination_type_ = LinearSolverTerminationType::FAILURE;
- } else {
- std::copy_n(q_transpose_rhs_.data(), num_cols_, solution);
- termination_type_ = LinearSolverTerminationType::SUCCESS;
- }
- return termination_type_;
- }
- #endif // CERES_NO_LAPACK
- #ifndef CERES_NO_CUDA
- CUDADenseQR::CUDADenseQR(ContextImpl* context)
- : context_(context),
- lhs_{context},
- rhs_{context},
- tau_{context},
- device_workspace_{context},
- error_(context, 1) {}
- LinearSolverTerminationType CUDADenseQR::Factorize(int num_rows,
- int num_cols,
- double* lhs,
- std::string* message) {
- factorize_result_ = LinearSolverTerminationType::FATAL_ERROR;
- lhs_.Reserve(num_rows * num_cols);
- tau_.Reserve(std::min(num_rows, num_cols));
- num_rows_ = num_rows;
- num_cols_ = num_cols;
- lhs_.CopyFromCpu(lhs, num_rows * num_cols);
- int device_workspace_size = 0;
- if (cusolverDnDgeqrf_bufferSize(context_->cusolver_handle_,
- num_rows,
- num_cols,
- lhs_.data(),
- num_rows,
- &device_workspace_size) !=
- CUSOLVER_STATUS_SUCCESS) {
- *message = "cuSolverDN::cusolverDnDgeqrf_bufferSize failed.";
- return LinearSolverTerminationType::FATAL_ERROR;
- }
- device_workspace_.Reserve(device_workspace_size);
- if (cusolverDnDgeqrf(context_->cusolver_handle_,
- num_rows,
- num_cols,
- lhs_.data(),
- num_rows,
- tau_.data(),
- reinterpret_cast<double*>(device_workspace_.data()),
- device_workspace_.size(),
- error_.data()) != CUSOLVER_STATUS_SUCCESS) {
- *message = "cuSolverDN::cusolverDnDgeqrf failed.";
- return LinearSolverTerminationType::FATAL_ERROR;
- }
- int error = 0;
- error_.CopyToCpu(&error, 1);
- if (error < 0) {
- LOG(FATAL) << "Congratulations, you found a bug in Ceres - "
- << "please report it. "
- << "cuSolverDN::cusolverDnDgeqrf fatal error. "
- << "Argument: " << -error << " is invalid.";
- // The following line is unreachable, but return failure just to be
- // pedantic, since the compiler does not know that.
- return LinearSolverTerminationType::FATAL_ERROR;
- }
- *message = "Success";
- factorize_result_ = LinearSolverTerminationType::SUCCESS;
- return LinearSolverTerminationType::SUCCESS;
- }
- LinearSolverTerminationType CUDADenseQR::Solve(const double* rhs,
- double* solution,
- std::string* message) {
- if (factorize_result_ != LinearSolverTerminationType::SUCCESS) {
- *message = "Factorize did not complete successfully previously.";
- return factorize_result_;
- }
- rhs_.CopyFromCpu(rhs, num_rows_);
- int device_workspace_size = 0;
- if (cusolverDnDormqr_bufferSize(context_->cusolver_handle_,
- CUBLAS_SIDE_LEFT,
- CUBLAS_OP_T,
- num_rows_,
- 1,
- num_cols_,
- lhs_.data(),
- num_rows_,
- tau_.data(),
- rhs_.data(),
- num_rows_,
- &device_workspace_size) !=
- CUSOLVER_STATUS_SUCCESS) {
- *message = "cuSolverDN::cusolverDnDormqr_bufferSize failed.";
- return LinearSolverTerminationType::FATAL_ERROR;
- }
- device_workspace_.Reserve(device_workspace_size);
- // Compute rhs = Q^T * rhs, assuming that lhs has already been factorized.
- // The result of factorization would have stored Q in a packed form in lhs_.
- if (cusolverDnDormqr(context_->cusolver_handle_,
- CUBLAS_SIDE_LEFT,
- CUBLAS_OP_T,
- num_rows_,
- 1,
- num_cols_,
- lhs_.data(),
- num_rows_,
- tau_.data(),
- rhs_.data(),
- num_rows_,
- reinterpret_cast<double*>(device_workspace_.data()),
- device_workspace_.size(),
- error_.data()) != CUSOLVER_STATUS_SUCCESS) {
- *message = "cuSolverDN::cusolverDnDormqr failed.";
- return LinearSolverTerminationType::FATAL_ERROR;
- }
- int error = 0;
- error_.CopyToCpu(&error, 1);
- if (error < 0) {
- LOG(FATAL) << "Congratulations, you found a bug in Ceres. "
- << "Please report it."
- << "cuSolverDN::cusolverDnDormqr fatal error. "
- << "Argument: " << -error << " is invalid.";
- }
- // Compute the solution vector as x = R \ (Q^T * rhs). Since the previous step
- // replaced rhs by (Q^T * rhs), this is just x = R \ rhs.
- if (cublasDtrsv(context_->cublas_handle_,
- CUBLAS_FILL_MODE_UPPER,
- CUBLAS_OP_N,
- CUBLAS_DIAG_NON_UNIT,
- num_cols_,
- lhs_.data(),
- num_rows_,
- rhs_.data(),
- 1) != CUBLAS_STATUS_SUCCESS) {
- *message = "cuBLAS::cublasDtrsv failed.";
- return LinearSolverTerminationType::FATAL_ERROR;
- }
- rhs_.CopyToCpu(solution, num_cols_);
- *message = "Success";
- return LinearSolverTerminationType::SUCCESS;
- }
- std::unique_ptr<CUDADenseQR> CUDADenseQR::Create(
- const LinearSolver::Options& options) {
- if (options.dense_linear_algebra_library_type != CUDA ||
- options.context == nullptr || !options.context->IsCudaInitialized()) {
- return nullptr;
- }
- return std::unique_ptr<CUDADenseQR>(new CUDADenseQR(options.context));
- }
- #endif // CERES_NO_CUDA
- } // namespace ceres::internal
|