// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_SOLVERBASE_H
#define EIGEN_SOLVERBASE_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

namespace internal {

/** \internal
 * Generic pre-solve check: hands the right-hand side over to the solver's own
 * _check_solve_assertion() member template.
 *
 * \tparam Derived the solver type on which _check_solve_assertion<Transpose_>() is invoked.
 */
template <typename Derived>
struct solve_assertion {
  /** \tparam Transpose_ forwarded unchanged as the template argument of _check_solve_assertion().
   * \param dec the solver object to validate
   * \param rhs the right-hand side handed to the check
   */
  template <bool Transpose_, typename Rhs>
  static void run(const Derived& dec, const Rhs& rhs) {
    dec.template _check_solve_assertion<Transpose_>(rhs);
  }
};

/** \internal
 * Pre-solve check for a transposed expression: unwraps the Transpose and runs the
 * check of the nested expression.
 */
template <typename Derived>
struct solve_assertion<Transpose<Derived>> {
  using type = Transpose<Derived>;

  /** The incoming \c Transpose_ flag is not consulted: the nested expression is always
   * checked with the flag set to \c true.
   */
  template <bool Transpose_, typename Rhs>
  static void run(const type& transpose, const Rhs& b) {
    // Checker matching the expression wrapped by the Transpose.
    using NestedAssertion = internal::solve_assertion<internal::remove_all_t<Derived>>;
    NestedAssertion::template run<true>(transpose.nestedExpression(), b);
  }
};

/** \internal
 * Pre-solve check for a conjugated transpose expression (a scalar_conjugate_op applied
 * to a const Transpose): strips the conjugate wrapper and runs the check of the nested
 * Transpose expression.
 */
template <typename Scalar, typename Derived>
struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived>>> {
  using type = CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived>>;

  /** The incoming \c Transpose_ flag is not consulted: the nested expression is always
   * checked with the flag set to \c true.
   */
  template <bool Transpose_, typename Rhs>
  static void run(const type& adjoint, const Rhs& b) {
    // Type of the expression left once the conjugate wrapper is removed.
    using NestedXpr = internal::remove_all_t<Transpose<Derived>>;
    internal::solve_assertion<NestedXpr>::template run<true>(adjoint.nestedExpression(), b);
  }
};
}  // end namespace internal

/** \class SolverBase
 * \brief A base class for matrix decomposition and solvers
 *
 * \tparam Derived the actual type of the decomposition/solver.
 *
 * Any matrix decomposition inheriting this base class provides the following API:
 *
 * \code
 * MatrixType A, b, x;
 * DecompositionType dec(A);
 * x = dec.solve(b);             // solve A   * x = b
 * x = dec.transpose().solve(b); // solve A^T * x = b
 * x = dec.adjoint().solve(b);   // solve A'  * x = b
 * \endcode
 *
 * \warning Currently, any other usage of transpose() or adjoint() is not supported and will produce compilation
 * errors.
 *
 * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR,
 * class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase
 */
template <typename Derived>
class SolverBase : public EigenBase<Derived> {
 public:
  using Base = EigenBase<Derived>;
  using Scalar = typename internal::traits<Derived>::Scalar;
  using CoeffReturnType = Scalar;

  using Base::derived;

  // Lets internal::solve_assertion call the protected _check_solve_assertion() declared below.
  template <typename Derived_>
  friend struct internal::solve_assertion;

  enum {
    // Compile-time dimensions, taken from the traits of the derived solver.
    RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime,
    ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
    SizeAtCompileTime = (internal::size_of_xpr_at_compile_time<Derived>::ret),
    MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
    MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
    MaxSizeAtCompileTime = internal::size_at_compile_time(internal::traits<Derived>::MaxRowsAtCompileTime,
                                                          internal::traits<Derived>::MaxColsAtCompileTime),
    // True as soon as one of the two maximal dimensions equals 1.
    IsVectorAtCompileTime =
        internal::traits<Derived>::MaxRowsAtCompileTime == 1 || internal::traits<Derived>::MaxColsAtCompileTime == 1,
    // 0 when the maximal size is 1, 1 for a compile-time vector, 2 otherwise.
    NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : (bool(IsVectorAtCompileTime) ? 1 : 2)
  };

  /** Default constructor */
  SolverBase() {}

  ~SolverBase() {}

  /** \returns the value of info() as reported by the derived solver.
   *
   * \note Derived is expected to declare its own <tt>ComputationInfo info() const</tt>. Without it, the
   * call below resolves to this very function and recurses forever (compilers can flag this through
   * -Winfinite-recursion).
   */
  ComputationInfo info() const { return derived().info(); }

  /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A.
   *
   * The solve assertions of the derived solver are run on \a b before the expression is built.
   */
  template <typename Rhs>
  inline const Solve<Derived, Rhs> solve(const MatrixBase<Rhs>& b) const {
    using Assertion = internal::solve_assertion<internal::remove_all_t<Derived>>;
    Assertion::template run<false>(derived(), b);
    return Solve<Derived, Rhs>(derived(), b.derived());
  }

  /** \internal the return type of transpose() */
  using ConstTransposeReturnType = Transpose<const Derived>;

  /** \returns an expression of the transpose of the factored matrix.
   *
   * It is typically used to solve the transposed problem A^T x = b:
   * \code x = dec.transpose().solve(b); \endcode
   *
   * \sa adjoint(), solve()
   */
  inline const ConstTransposeReturnType transpose() const { return ConstTransposeReturnType(derived()); }

  /** \internal the return type of adjoint(): a conjugate wrapper around the transpose expression for complex
   * scalars, the plain transpose expression otherwise */
  using AdjointReturnType =
      std::conditional_t<NumTraits<Scalar>::IsComplex,
                         CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const ConstTransposeReturnType>,
                         const ConstTransposeReturnType>;

  /** \returns an expression of the adjoint of the factored matrix.
   *
   * It is typically used to solve the adjoint problem A' x = b:
   * \code x = dec.adjoint().solve(b); \endcode
   *
   * For real scalar types, this function is equivalent to transpose().
   *
   * \sa transpose(), solve()
   */
  inline const AdjointReturnType adjoint() const { return AdjointReturnType(derived().transpose()); }

 protected:
  /** \internal
   * Debug-only sanity checks performed before solving: the solver must be initialized, and \a b must have
   * as many rows as the solver has columns (when \c Transpose_ is true) or rows (when it is false).
   */
  template <bool Transpose_, typename Rhs>
  void _check_solve_assertion(const Rhs& b) const {
    // b is only read inside the assertions below.
    EIGEN_ONLY_USED_FOR_DEBUG(b);
    eigen_assert(derived().m_isInitialized && "Solver is not initialized.");
    eigen_assert((Transpose_ ? derived().cols() : derived().rows()) == b.rows() &&
                 "SolverBase::solve(): invalid number of rows of the right hand side matrix b");
  }
};

namespace internal {

/** \internal
 * Specialization of generic_xpr_base for the (MatrixXpr, SolverStorage) pair:
 * its nested \c type is SolverBase<Derived>.
 */
template <typename Derived>
struct generic_xpr_base<Derived, MatrixXpr, SolverStorage> {
  using type = SolverBase<Derived>;
};

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_SOLVERBASE_H
