// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_INDEXED_VIEW_H
#define EIGEN_INDEXED_VIEW_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

namespace internal {

/** \internal
 * Compile-time properties of an IndexedView expression.
 *
 * Everything not overridden here is inherited from traits<XprType>. The enumerators below derive the
 * view's sizes, storage order, per-dimension index increments, compile-time strides and flags from the
 * nested expression type and from IndexedViewHelper applied to the two index types.
 */
template <typename XprType, typename RowIndices, typename ColIndices>
struct traits<IndexedView<XprType, RowIndices, ColIndices>> : traits<XprType> {
  enum {
    // Sizes are dictated by the index objects; the maximum sizes are taken equal to the sizes.
    RowsAtCompileTime = int(IndexedViewHelper<RowIndices>::SizeAtCompileTime),
    ColsAtCompileTime = int(IndexedViewHelper<ColIndices>::SizeAtCompileTime),
    MaxRowsAtCompileTime = RowsAtCompileTime,
    MaxColsAtCompileTime = ColsAtCompileTime,

    // Storage order: a compile-time row vector is row-major, a compile-time column vector is
    // column-major, anything else keeps the storage order of the nested expression.
    XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0,
    IsRowMajor = (MaxRowsAtCompileTime == 1 && MaxColsAtCompileTime != 1)   ? 1
                 : (MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1) ? 0
                                                                            : XprTypeIsRowMajor,

    // Compile-time increments of the row/column index objects, then mapped to the view's
    // inner/outer dimensions according to IsRowMajor.
    RowIncr = int(IndexedViewHelper<RowIndices>::IncrAtCompileTime),
    ColIncr = int(IndexedViewHelper<ColIndices>::IncrAtCompileTime),
    InnerIncr = IsRowMajor ? ColIncr : RowIncr,
    OuterIncr = IsRowMajor ? RowIncr : ColIncr,

    // Strides of the nested expression expressed along the view's inner/outer dimensions: when the
    // view's storage order differs from the nested expression's, the two strides are swapped.
    HasSameStorageOrderAsXprType = (IsRowMajor == XprTypeIsRowMajor),
    XprInnerStride = HasSameStorageOrderAsXprType ? int(inner_stride_at_compile_time<XprType>::ret)
                                                  : int(outer_stride_at_compile_time<XprType>::ret),
    XprOuterstride = HasSameStorageOrderAsXprType ? int(outer_stride_at_compile_time<XprType>::ret)
                                                  : int(inner_stride_at_compile_time<XprType>::ret),

    // InnerSize is the view's size along the inner dimension of the nested expression.
    InnerSize = XprTypeIsRowMajor ? ColsAtCompileTime : RowsAtCompileTime,
    // Both increments are known at compile time to be 1.
    IsBlockAlike = InnerIncr == 1 && OuterIncr == 1,
    // Same storage order as the nested expression, and the index type along the nested expression's
    // inner dimension is exactly AllRange<InnerSize>.
    IsInnerPannel = HasSameStorageOrderAsXprType &&
                    is_same<AllRange<InnerSize>, std::conditional_t<XprTypeIsRowMajor, ColIndices, RowIndices>>::value,

    // A stride is known at compile time only if the increment is a defined, non-negative compile-time
    // value and the corresponding stride of the nested expression is itself known; otherwise Dynamic.
    InnerStrideAtCompileTime =
        InnerIncr < 0 || InnerIncr == DynamicIndex || XprInnerStride == Dynamic || InnerIncr == Undefined
            ? Dynamic
            : XprInnerStride * InnerIncr,
    OuterStrideAtCompileTime =
        OuterIncr < 0 || OuterIncr == DynamicIndex || XprOuterstride == Dynamic || OuterIncr == Undefined
            ? Dynamic
            : XprOuterstride * OuterIncr,

    // Mutually exclusive classification of the view: exactly one of the three is non-zero.
    // NOTE(review): presumably consumed by the operator() overloads that build the view to pick the
    // return type (scalar, Block or IndexedView) -- those callers are not visible in this file.
    ReturnAsScalar = is_single_range<RowIndices>::value && is_single_range<ColIndices>::value,
    ReturnAsBlock = (!ReturnAsScalar) && IsBlockAlike,
    ReturnAsIndexedView = (!ReturnAsScalar) && (!ReturnAsBlock),

    // FIXME we deal with compile-time strides if and only if we have DirectAccessBit flag,
    // but this is too strict regarding negative strides...
    // DirectAccessBit is allowed through only when both increments are defined and non-negative; since
    // it is used as a mask on traits<XprType>::Flags below, the nested expression must have the bit too.
    DirectAccessMask = (int(InnerIncr) != Undefined && int(OuterIncr) != Undefined && InnerIncr >= 0 && OuterIncr >= 0)
                           ? DirectAccessBit
                           : 0,
    FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
    FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
    // Linear (single index) access is advertised only for compile-time vectors.
    FlagsLinearAccessBit = (RowsAtCompileTime == 1 || ColsAtCompileTime == 1) ? LinearAccessBit : 0,
    Flags = (traits<XprType>::Flags & (HereditaryBits | DirectAccessMask)) | FlagsLvalueBit | FlagsRowMajorBit |
            FlagsLinearAccessBit
  };

  // Block type with the same compile-time sizes as the view; IsInnerPannel is forwarded as Block's
  // fourth template argument.
  typedef Block<XprType, RowsAtCompileTime, ColsAtCompileTime, IsInnerPannel> BlockType;
};

// Forward declaration of the implementation base class of IndexedView. The generic definition and the
// specialization for DirectAccess == true both appear further down in this file.
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
class IndexedViewImpl;

}  // namespace internal

/** \class IndexedView
 * \ingroup Core_Module
 *
 * \brief Expression of a non-sequential sub-matrix defined by arbitrary sequences of row and column indices
 *
 * \tparam XprType the type of the expression in which we are taking the intersections of sub-rows and sub-columns
 * \tparam RowIndices the type of the object defining the sequence of row indices
 * \tparam ColIndices the type of the object defining the sequence of column indices
 *
 * This class represents an expression of a sub-matrix (or sub-vector) defined as the intersection
 * of sub-sets of rows and columns, that are themselves defined by generic sequences of row indices \f$
 * \{r_0,r_1,..r_{m-1}\} \f$ and column indices \f$ \{c_0,c_1,..c_{n-1} \}\f$. Let \f$ A \f$  be the nested matrix, then
 * the resulting matrix \f$ B \f$ has \c m rows and \c n columns, and its entries are given by: \f$ B(i,j) = A(r_i,c_j)
 * \f$.
 *
 * The \c RowIndices and \c ColIndices types must be compatible with the following API:
 * \code
 * <integral type> operator[](Index) const;
 * Index size() const;
 * \endcode
 *
 * Typical supported types thus include:
 *  - std::vector<int>
 *  - std::valarray<int>
 *  - std::array<int,N>
 *  - Eigen::ArrayXi
 *  - decltype(ArrayXi::LinSpaced(...))
 *  - Any view/expressions of the previous types
 *  - Eigen::ArithmeticSequence
 *  - Eigen::internal::AllRange     (helper for Eigen::placeholders::all)
 *  - Eigen::internal::SingleRange  (helper for single index)
 *  - etc.
 *
 * In typical usages of %Eigen, this class should never be used directly. It is the return type of
 * DenseBase::operator()(const RowIndices&, const ColIndices&).
 *
 * \sa class Block
 */
template <typename XprType, typename RowIndices, typename ColIndices>
class IndexedView
    : public internal::IndexedViewImpl<
          XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
          (internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags & DirectAccessBit) != 0> {
 public:
  /** Implementation base class. Its last template argument is \c true exactly when the traits of this
   * view carry DirectAccessBit. */
  using Base = internal::IndexedViewImpl<
      XprType, RowIndices, ColIndices, typename internal::traits<XprType>::StorageKind,
      (internal::traits<IndexedView<XprType, RowIndices, ColIndices>>::Flags & DirectAccessBit) != 0>;

  EIGEN_GENERIC_PUBLIC_INTERFACE(IndexedView)
  EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedView)

  /** Builds a view of \a xpr restricted to the given row and column indices.
   *
   * All three arguments are forwarded unchanged to the implementation base class.
   */
  template <typename T0, typename T1>
  IndexedView(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}
};

namespace internal {

/** \internal
 * Generic implementation of IndexedView.
 *
 * Holds the nested expression together with the row and column index objects, reports the view's
 * dimensions through IndexedViewHelper, and accesses coefficients by looking up the nested row/column
 * through the stored index objects.
 */
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind, bool DirectAccess>
class IndexedViewImpl : public internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type {
 public:
  using Base = typename internal::generic_xpr_base<IndexedView<XprType, RowIndices, ColIndices>>::type;
  using MatrixTypeNested = typename internal::ref_selector<XprType>::non_const_type;
  using NestedExpression = internal::remove_all_t<XprType>;
  using Scalar = typename XprType::Scalar;

  EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)

  /** Stores \a xpr and initializes the row/column index members from \a rowIndices and \a colIndices. */
  template <typename T0, typename T1>
  IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices)
      : m_xpr(xpr), m_rowIndices(rowIndices), m_colIndices(colIndices) {}

  /** \returns the number of rows of the view, as reported by the row index object */
  Index rows() const { return IndexedViewHelper<RowIndices>::size(m_rowIndices); }

  /** \returns the number of columns of the view, as reported by the column index object */
  Index cols() const { return IndexedViewHelper<ColIndices>::size(m_colIndices); }

  /** \returns a const reference to the nested expression */
  const NestedExpression& nestedExpression() const { return m_xpr; }

  /** \returns a mutable reference to the nested expression */
  std::remove_reference_t<XprType>& nestedExpression() { return m_xpr; }

  /** \returns a const reference to the object storing/generating the row indices */
  const RowIndices& rowIndices() const { return m_rowIndices; }

  /** \returns a const reference to the object storing/generating the column indices */
  const ColIndices& colIndices() const { return m_colIndices; }

  /** \returns a reference to the nested coefficient selected by the \a rowId -th row index and the
   * \a colId -th column index */
  constexpr Scalar& coeffRef(Index rowId, Index colId) {
    return this->nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
  }

  /** Const overload of coeffRef(Index, Index). */
  constexpr const Scalar& coeffRef(Index rowId, Index colId) const {
    return this->nestedExpression().coeffRef(m_rowIndices[rowId], m_colIndices[colId]);
  }

 protected:
  MatrixTypeNested m_xpr;   // nested expression, held as selected by ref_selector<XprType>
  RowIndices m_rowIndices;  // copy of the row index object
  ColIndices m_colIndices;  // copy of the column index object
};

/** \internal
 * Specialization of IndexedViewImpl selected when the view has DirectAccessBit.
 *
 * On top of the generic implementation it exposes the raw data pointer and the inner/outer strides of
 * the view, computed from the index increments and the strides of the nested expression.
 */
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
class IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, true>
    : public IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, false> {
 public:
  // Spelled with the StorageKind template parameter so that the alias always names the actual base class.
  using Base = internal::IndexedViewImpl<XprType, RowIndices, ColIndices, StorageKind, false>;
  using Derived = IndexedView<XprType, RowIndices, ColIndices>;

  EIGEN_INHERIT_ASSIGNMENT_OPERATORS(IndexedViewImpl)

  /** Forwards all arguments to the generic implementation. */
  template <typename T0, typename T1>
  IndexedViewImpl(XprType& xpr, const T0& rowIndices, const T1& colIndices) : Base(xpr, rowIndices, colIndices) {}

  /** \returns the step between two consecutive row indices: the compile-time value when it is known,
   * otherwise the value queried from the row index object */
  Index rowIncrement() const {
    if (traits<Derived>::RowIncr != DynamicIndex && traits<Derived>::RowIncr != Undefined) {
      return traits<Derived>::RowIncr;
    }
    return IndexedViewHelper<RowIndices>::incr(this->rowIndices());
  }

  /** \returns the step between two consecutive column indices: the compile-time value when it is known,
   * otherwise the value queried from the column index object */
  Index colIncrement() const {
    if (traits<Derived>::ColIncr != DynamicIndex && traits<Derived>::ColIncr != Undefined) {
      return traits<Derived>::ColIncr;
    }
    return IndexedViewHelper<ColIndices>::incr(this->colIndices());
  }

  /** \returns the index increment along the inner dimension of the view */
  Index innerIncrement() const { return traits<Derived>::IsRowMajor ? colIncrement() : rowIncrement(); }

  /** \returns the index increment along the outer dimension of the view */
  Index outerIncrement() const { return traits<Derived>::IsRowMajor ? rowIncrement() : colIncrement(); }

  /** \returns the nested data pointer advanced to the nested coefficient located at
   * (rowIndices()[0], colIndices()[0]) */
  std::decay_t<typename XprType::Scalar>* data() {
    Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
    Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
    return this->nestedExpression().data() + row_offset + col_offset;
  }

  /** Const overload of data(). */
  const std::decay_t<typename XprType::Scalar>* data() const {
    Index row_offset = this->rowIndices()[0] * this->nestedExpression().rowStride();
    Index col_offset = this->colIndices()[0] * this->nestedExpression().colStride();
    return this->nestedExpression().data() + row_offset + col_offset;
  }

  /** \returns the pointer increment between two consecutive coefficients along the view's inner dimension.
   *
   * The compile-time stride is returned when it is known. Otherwise the runtime increment is scaled by
   * the stride of the nested expression along that same dimension: a vector view may have a storage
   * order that differs from the nested expression, in which case the view's inner dimension is the
   * nested expression's outer dimension. This mirrors traits<Derived>::XprInnerStride.
   */
  EIGEN_DEVICE_FUNC constexpr Index innerStride() const noexcept {
    if (traits<Derived>::InnerStrideAtCompileTime != Dynamic) {
      return traits<Derived>::InnerStrideAtCompileTime;
    }
    return innerIncrement() * (traits<Derived>::HasSameStorageOrderAsXprType
                                   ? this->nestedExpression().innerStride()
                                   : this->nestedExpression().outerStride());
  }

  /** \returns the pointer increment between two consecutive inner slices of the view.
   *
   * The compile-time stride is returned when it is known. Otherwise the runtime increment is scaled by
   * the stride of the nested expression along the view's outer dimension, which is the nested
   * expression's inner dimension when the storage orders differ. This mirrors
   * traits<Derived>::XprOuterstride.
   */
  EIGEN_DEVICE_FUNC constexpr Index outerStride() const noexcept {
    if (traits<Derived>::OuterStrideAtCompileTime != Dynamic) {
      return traits<Derived>::OuterStrideAtCompileTime;
    }
    return outerIncrement() * (traits<Derived>::HasSameStorageOrderAsXprType
                                   ? this->nestedExpression().outerStride()
                                   : this->nestedExpression().innerStride());
  }
};

/** \internal
 * Index-based evaluator of an IndexedView.
 *
 * Every coefficient access translates the view's (row, col) position into a position of the nested
 * expression through the view's row and column index objects, checks it against the nested
 * expression's dimensions, and forwards to the evaluator of the nested expression.
 */
template <typename ArgType, typename RowIndices, typename ColIndices>
struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
    : evaluator_base<IndexedView<ArgType, RowIndices, ColIndices>> {
  using XprType = IndexedView<ArgType, RowIndices, ColIndices>;
  using Scalar = typename XprType::Scalar;
  using CoeffReturnType = typename XprType::CoeffReturnType;

  enum {
    CoeffReadCost = evaluator<ArgType>::CoeffReadCost /* TODO + cost of row/col index */,

    // Linear access is only offered for compile-time vectors.
    FlagsLinearAccessBit =
        (traits<XprType>::RowsAtCompileTime == 1 || traits<XprType>::ColsAtCompileTime == 1) ? LinearAccessBit : 0,

    FlagsRowMajorBit = traits<XprType>::FlagsRowMajorBit,

    // Hereditary bits of the nested evaluator, with the storage order replaced by the view's own.
    Flags = (evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit /*| LinearAccessBit | DirectAccessBit*/)) |
            FlagsLinearAccessBit | FlagsRowMajorBit,

    Alignment = 0
  };

  /** Creates the evaluator of the nested expression and keeps a reference to \a xpr for index lookup. */
  EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr) {
    EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
  }

  /** \returns the nested coefficient selected by position (\a row, \a col) of the view */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
    eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
                 m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
    return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  }

  /** \returns a reference to the nested coefficient selected by position (\a row, \a col) of the view */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
    eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
                 m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
    return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  }

  /** Linear write access: \a index runs along the columns when the view has one row at compile time,
   * and along the rows otherwise. */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    EIGEN_STATIC_ASSERT_LVALUE(XprType)
    const bool isRowVector = (XprType::RowsAtCompileTime == 1);
    const Index row = isRowVector ? Index(0) : index;
    const Index col = isRowVector ? index : Index(0);
    eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
                 m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
    return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  }

  /** Const overload of coeffRef(Index), using the same mapping of \a index to (row, col). */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeffRef(Index index) const {
    const bool isRowVector = (XprType::RowsAtCompileTime == 1);
    const Index row = isRowVector ? Index(0) : index;
    const Index col = isRowVector ? index : Index(0);
    eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
                 m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
    return m_argImpl.coeffRef(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  }

  /** Linear read access, using the same mapping of \a index to (row, col) as coeffRef(Index). */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CoeffReturnType coeff(Index index) const {
    const bool isRowVector = (XprType::RowsAtCompileTime == 1);
    const Index row = isRowVector ? Index(0) : index;
    const Index col = isRowVector ? index : Index(0);
    eigen_assert(m_xpr.rowIndices()[row] >= 0 && m_xpr.rowIndices()[row] < m_xpr.nestedExpression().rows() &&
                 m_xpr.colIndices()[col] >= 0 && m_xpr.colIndices()[col] < m_xpr.nestedExpression().cols());
    return m_argImpl.coeff(m_xpr.rowIndices()[row], m_xpr.colIndices()[col]);
  }

 protected:
  evaluator<ArgType> m_argImpl;  // evaluator of the nested expression
  const XprType& m_xpr;          // the evaluated view, used to reach its index objects
};

// Catch assignments to an IndexedView.
// Sets evaluator_assume_aliasing<...>::value to true for every IndexedView instantiation.
// NOTE(review): presumably this makes assignments involving an IndexedView treat the operands as
// potentially aliased (evaluating through a temporary) -- the consumer of this trait is not visible
// in this file, confirm against the assignment code.
template <typename ArgType, typename RowIndices, typename ColIndices>
struct evaluator_assume_aliasing<IndexedView<ArgType, RowIndices, ColIndices>> {
  static const bool value = true;
};

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_INDEXED_VIEW_H
