#pragma once
#include "../common/std_headers.h"

#include <cmath>
#include <limits>
#include <random>

#include "../common/packetbuf.h"
#include "../common/str_utils.h"

/*
  MinMax: a pair of doubles named min and max.
  Default construction gives min == max == 0.0.
*/
struct MinMax {
  double min{0.0};
  double max{0.0};

  MinMax() = default;

  explicit MinMax(double _min, double _max)
    : min(_min), max(_max)
  {
  }
};


/*
  normangle: wrap an angle in radians into [-pi, pi].
  Inputs already inside [-pi, pi] come back unchanged; anything outside is
  reduced by whole turns with fmod(..., M_2PI).
*/
static inline double normangle(double x)
{
  if (x < -M_PI) {
    // Reflect, reduce exactly as in the positive case, then reflect back.
    return -(fmod(-x + M_PI, M_2PI) - M_PI);
  }
  if (x > M_PI) {
    return fmod(x + M_PI, M_2PI) - M_PI;
  }
  return x;
}

/*
  normangle2: lazy variant of normangle. Angles within [-3pi, 3pi] are
  returned unchanged; only larger magnitudes are wrapped into [-pi, pi].
*/
static inline double normangle2(double x)
{
  double const limit = 3 * M_PI;
  if (x < -limit) {
    return -(fmod(-x + M_PI, M_2PI) - M_PI);
  }
  if (x > limit) {
    return fmod(x + M_PI, M_2PI) - M_PI;
  }
  return x;
}

/*
  float overload of normangle: wrap an angle in radians into [-pi, pi].
  M_PI is a double constant, so the reduction is evaluated in double and the
  result is narrowed back to float.
*/
static inline float normangle(float x)
{
  if (x < -M_PI) {
    return static_cast<float>(-(fmod(-x + M_PI, M_2PI) - M_PI));
  }
  if (x > M_PI) {
    return static_cast<float>(fmod(x + M_PI, M_2PI) - M_PI);
  }
  return x;
}

/*
  float overload of normangle2: values within [-3pi, 3pi] are returned
  unchanged; larger magnitudes are wrapped into [-pi, pi] (computed in double).
*/
static inline float normangle2(float x)
{
  double const limit = 3 * M_PI;
  if (x < -limit) {
    return static_cast<float>(-(fmod(-x + M_PI, M_2PI) - M_PI));
  }
  if (x > limit) {
    return static_cast<float>(fmod(x + M_PI, M_2PI) - M_PI);
  }
  return x;
}

// Square of x.
static inline double sqr(double x)
{
  double const x2 = x * x;
  return x2;
}

// Cube of x, evaluated as (x * x) * x.
static inline double cube(double x)
{
  double const x2 = x * x;
  return x2 * x;
}

/*
  easeInRaisedCos: raised-cosine ease from 0 to 1 over x in [0, 1].
  Returns 0 for x <= 0, 1 for x >= 1, and (1 - cos(pi * x)) / 2 in between.
*/
static inline double easeInRaisedCos(double x)
{
  return x <= 0.0 ? 0.0
       : x >= 1.0 ? 1.0
       : (1.0 - cos(x * M_PI)) * 0.5;
}

/*
  linearComb(aCoeff, a, bCoeff, b) returns aCoeff * a + bCoeff * b.
  The float overload evaluates in double and narrows the result to float.
*/
static inline double linearComb(double aCoeff, double const &a, double bCoeff, double const &b)
{
  return aCoeff * a + bCoeff * b;
}
static inline float linearComb(double aCoeff, float const &a, double bCoeff, float const &b)
{
  return static_cast<float>(aCoeff * a + bCoeff * b);
}
/*
  Convert the double result of an integer linearComb back to Int.
  In-range values truncate toward zero exactly like a plain cast. Values
  outside Int's range saturate at its limits, and NaN maps to 0. A plain cast
  of such values is undefined behaviour, e.g. extrapolating unsigned values
  below zero.
*/
template <typename Int>
static inline Int linearCombToInt(double v)
{
  if (std::isnan(v)) return Int(0);
  if (v <= static_cast<double>(std::numeric_limits<Int>::lowest())) return std::numeric_limits<Int>::lowest();
  // (double)max may round up past max (e.g. 2^64 for U64), so ">=" keeps the cast below in range.
  if (v >= static_cast<double>(std::numeric_limits<Int>::max())) return std::numeric_limits<Int>::max();
  return static_cast<Int>(v);
}

/*
  Integer overloads of linearComb: compute aCoeff * a + bCoeff * b in double,
  then convert back with linearCombToInt (truncating, saturating).
*/
static inline S64 linearComb(double aCoeff, S64 const &a, double bCoeff, S64 const &b)
{
  return linearCombToInt<S64>(aCoeff * (double)a + bCoeff * (double)b);
}
static inline S32 linearComb(double aCoeff, S32 const &a, double bCoeff, S32 const &b)
{
  return linearCombToInt<S32>(aCoeff * (double)a + bCoeff * (double)b);
}
static inline U64 linearComb(double aCoeff, U64 const &a, double bCoeff, U64 const &b)
{
  return linearCombToInt<U64>(aCoeff * (double)a + bCoeff * (double)b);
}
static inline U32 linearComb(double aCoeff, U32 const &a, double bCoeff, U32 const &b)
{
  return linearCombToInt<U32>(aCoeff * (double)a + bCoeff * (double)b);
}
// Complex blend: aCoeff * a + bCoeff * b, with real coefficients scaling both components.
static inline std::complex<double>
linearComb(double aCoeff, std::complex<double> const &a, double bCoeff, std::complex<double> const &b)
{
  return (aCoeff * a) + (bCoeff * b);
}

/*
  Strings cannot be blended: return the operand whose coefficient is strictly
  larger. When the coefficients are equal, b is returned.
*/
static inline string linearComb(double aCoeff, string const &a, double bCoeff, string const &b)
{
  if (aCoeff > bCoeff) {
    return a;
  }
  return b;
}

#ifdef USE_EIGEN3
/*
  linearComb for Eigen matrices.
  Non-integer scalar types: aCoeff * a + bCoeff * b.
  Integer scalar types are not blended. The operand with the strictly larger
  coefficient is returned, and ties go to b.
*/
template <typename _Scalar, int _Rows, int _Cols>
Eigen::Matrix<_Scalar, _Rows, _Cols> linearComb(
    double aCoeff,
    Eigen::Matrix<_Scalar, _Rows, _Cols> const &a,
    double bCoeff,
    Eigen::Matrix<_Scalar, _Rows, _Cols> const &b)
{
  if constexpr (!Eigen::NumTraits<_Scalar>::IsInteger) {
    return aCoeff * a + bCoeff * b;
  }
  else {
    return aCoeff > bCoeff ? a : b;
  }
}
#endif

/*
  Linear interpolation: a + t * (b - a).
  Returns a at t == 0 and b at t == 1; t outside [0, 1] extrapolates.
*/
template <typename T, typename U>
inline auto lerp(T a, T b, U t) -> T
{
  // t scales the span (b - a); adding it to a moves from a toward b.
  return a + t * (b - a);
}


/*
  Maps are not blended key by key. The operand with the strictly larger
  coefficient is returned whole, and ties go to b.
*/
template <typename T>
map<string, T> linearComb(double aCoeff, map<string, T> const &a, double bCoeff, map<string, T> const &b)
{
  if (aCoeff > bCoeff) {
    return a;
  }
  return b;
}

/*
  Elementwise linearComb for vectors. The result has the length of the longer
  input; the shorter input is padded with value-initialized T().
*/
template <typename T>
vector<T> linearComb(double aCoeff, vector<T> const &a, double bCoeff, vector<T> const &b)
{
  vector<T> ret(max(a.size(), b.size()));
  for (size_t k = 0; k < ret.size(); ++k) {
    T const ak = k < a.size() ? a[k] : T();
    T const bk = k < b.size() ? b[k] : T();
    ret[k] = linearComb(aCoeff, ak, bCoeff, bk);
  }
  return ret;
}

/*
  linearMetric(a, b): bilinear product used as an inner product.
  For scalars this is simply a * b, evaluated in double.
*/
static inline double linearMetric(double const &a, double const &b)
{
  return a * b;
}
static inline double linearMetric(float const &a, float const &b)
{
  return static_cast<double>(a) * static_cast<double>(b);
}
// Integer overloads: widen both operands to double before multiplying, so the
// product is computed in floating point rather than in the integer type.
static inline double linearMetric(S64 const &a, S64 const &b)
{
  return static_cast<double>(a) * static_cast<double>(b);
}
static inline double linearMetric(S32 const &a, S32 const &b)
{
  return static_cast<double>(a) * static_cast<double>(b);
}
static inline double linearMetric(U64 const &a, U64 const &b)
{
  return static_cast<double>(a) * static_cast<double>(b);
}
static inline double linearMetric(U32 const &a, U32 const &b)
{
  return static_cast<double>(a) * static_cast<double>(b);
}
// Complex overload: real part of the plain (non-conjugated) product a * b.
static inline double linearMetric(std::complex<double> const &a, std::complex<double> const &b)
{
  std::complex<double> const product = a * b;
  return product.real();
}

// Strings carry no numeric content, so they contribute nothing to the metric.
static inline double linearMetric(string const &a, string const &b)
{
  (void)a;
  (void)b;
  return 0.0;
}

#ifdef USE_EIGEN3
/*
  linearMetric for Eigen matrices: sum of elementwise products (the Frobenius
  inner product). The other linearMetric overloads are bilinear: scalars give
  a * b, and vectors give the sum over elements. The previous norm of the
  elementwise product was always non-negative and not bilinear.
  For complex scalars the real part is taken, matching the std::complex
  overload (real(a * b), no conjugation). a and b must have the same shape.
*/
template <typename _Scalar, int _Rows, int _Cols>
double linearMetric(Eigen::Matrix<_Scalar, _Rows, _Cols > const &a, Eigen::Matrix< _Scalar, _Rows, _Cols> const &b)
{
  return static_cast<double>(std::real((a.array() * b.array()).sum()));
}
#endif

/*
  linearMetric for maps: sums linearMetric over the union of keys of a and b.
  A key present in only one map is paired with a value-initialized T().
  Returns the accumulated sum. Previously the sum was computed and then
  discarded by returning a literal 0.0.
*/
template <typename T>
double linearMetric(map<string, T> const &a, map<string, T> const &b)
{
  set<string> keys;
  for (auto const &entry : a) {
    keys.insert(entry.first);
  }
  for (auto const &entry : b) {
    keys.insert(entry.first);
  }
  double ret = 0.0;
  for (auto const &key : keys) {
    auto ait = a.find(key);
    auto bit = b.find(key);
    ret += linearMetric((ait == a.end() ? T() : ait->second), (bit == b.end() ? T() : bit->second));
  }
  return ret;
}

/*
  linearMetric for vectors: sum of per-element linearMetric values. The
  shorter vector is padded with value-initialized T().
*/
template <typename T>
double linearMetric(vector<T> const &a, vector<T> const &b)
{
  size_t const n = max(a.size(), b.size());
  double total = 0.0;
  for (size_t k = 0; k < n; ++k) {
    T const ak = k < a.size() ? a[k] : T();
    T const bk = k < b.size() ? b[k] : T();
    total += linearMetric(ak, bk);
  }
  return total;
}

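/*
  randomSample(pdist, rand): returns a sample of pdist.
  In this header, scalar, string and Eigen overloads treat the value as a
  point distribution and return it unchanged without consuming randomness.
  The vector, map and shared_ptr overloads recurse into their elements.
*/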

// A double is a point distribution: its only sample is itself. rand is untouched.
static inline double randomSample(double const &pdist, std::default_random_engine &rand)
{
  (void)rand;
  return pdist;
}

// Integer point distributions: each overload returns its input unchanged and
// does not draw from rand.
static inline S32 randomSample(S32 const &pdist, std::default_random_engine &rand)
{
  (void)rand;
  return pdist;
}

static inline S64 randomSample(S64 const &pdist, std::default_random_engine &rand)
{
  (void)rand;
  return pdist;
}

static inline U32 randomSample(U32 const &pdist, std::default_random_engine &rand)
{
  (void)rand;
  return pdist;
}

static inline U64 randomSample(U64 const &pdist, std::default_random_engine &rand)
{
  (void)rand;
  return pdist;
}

// bool and string are point distributions: the input is returned unchanged.
static inline bool randomSample(bool const &pdist, std::default_random_engine &rand)
{
  (void)rand;
  return pdist;
}

static inline string randomSample(string const &pdist, std::default_random_engine &rand)
{
  (void)rand;
  return pdist;
}

/*
  Container overloads sample every element in iteration order, using the same
  rand engine, and preserve the container's shape.
*/
template <typename T>
vector<T> randomSample(vector<T> const &pdist, std::default_random_engine &rand)
{
  vector<T> ret(pdist.size());
  for (size_t k = 0; k < pdist.size(); ++k) {
    ret[k] = randomSample(pdist[k], rand);
  }
  return ret;
}

// A null pointer samples to null; otherwise the pointee is sampled into a fresh object.
template <typename T>
shared_ptr<T> randomSample(shared_ptr<T> const &pdist, std::default_random_engine &rand)
{
  if (pdist) {
    return make_shared<T>(randomSample(*pdist, rand));
  }
  return nullptr;
}

template <typename K, typename V>
map<K, V> randomSample(map<K, V> const &pdist, std::default_random_engine &rand)
{
  map<K, V> ret;
  for (auto const &entry : pdist) {
    ret[entry.first] = randomSample(entry.second, rand);
  }
  return ret;
}

#ifdef USE_EIGEN3
// Eigen matrices are treated as point distributions: returned as-is, rand untouched.
template <typename _Scalar, int _Rows, int _Cols>
Eigen::Matrix<_Scalar, _Rows, _Cols>
randomSample(Eigen::Matrix<_Scalar, _Rows, _Cols> const &pdist, std::default_random_engine &rand)
{
  (void)rand;
  return pdist;
}
#endif

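/*
  hasNaN returns true if any floating-point value is NaN, searching
  recursively through maps, vectors, shared_ptrs and Eigen matrices.
  Null pointers, empty containers, and integer or string values return false.
*/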

// Forward declaration so the map and vector hasNaN templates below can find the
// shared_ptr overload by ordinary lookup when their element type is a shared_ptr.
template <typename T>
bool hasNaN(shared_ptr<T> const &a);

// Floating-point leaves: true exactly when the value is NaN.
static inline bool hasNaN(double const &a)
{
  return std::isnan(a) != 0;
}
static inline bool hasNaN(float const &a)
{
  return std::isnan(a) != 0;
}
// Integer leaves can never hold NaN.
static inline bool hasNaN(S64 const &a)
{
  (void)a;
  return false;
}
static inline bool hasNaN(S32 const &a)
{
  (void)a;
  return false;
}
static inline bool hasNaN(U64 const &a)
{
  (void)a;
  return false;
}
static inline bool hasNaN(U32 const &a)
{
  (void)a;
  return false;
}

// Strings are never NaN, even when their text spells "nan".
static inline bool hasNaN(string const &a)
{
  (void)a;
  return false;
}

#ifdef USE_EIGEN3
// Eigen matrices: delegate to Eigen's own coefficient-wise NaN scan.
template <typename _Scalar, int _Rows, int _Cols>
bool hasNaN(Eigen::Matrix<_Scalar, _Rows, _Cols> const &a)
{
  bool const found = a.hasNaN();
  return found;
}
#endif

// Maps: NaN anywhere among the values. Stops at the first hit.
template <typename T>
bool hasNaN(map<string, T> const &a)
{
  for (auto const &entry : a) {
    if (hasNaN(entry.second)) return true;
  }
  return false;
}

// Vectors: NaN in any element. Stops at the first hit.
template <typename T>
bool hasNaN(vector<T> const &a)
{
  for (auto const &elem : a) {
    if (hasNaN(elem)) return true;
  }
  return false;
}

// shared_ptr: a null pointer holds no NaN; otherwise inspect the pointee.
template <typename T>
bool hasNaN(shared_ptr<T> const &a)
{
  return a ? hasNaN(*a) : false;
}

/*
  Generic fallback for applyNamedChange: a type with no dedicated overload or
  specialization accepts no named changes. obj is left untouched and false
  is returned.
*/
template <typename T>
bool applyNamedChange(T &obj, char const *key, double weight, double value)
{
  (void)obj;
  (void)key;
  (void)weight;
  (void)value;
  return false;
}

// Explicit specializations of applyNamedChange for the leaf value types.
// They are only declared here; their definitions are not in this header
// (presumably in a separate .cc file — verify). Declaring them before any use
// ensures the generic "return false" template is not implicitly instantiated
// for these types.
template <>
bool applyNamedChange(double &obj, char const *key, double weight, double value);

template <>
bool applyNamedChange(float &obj, char const *key, double weight, double value);

template <>
bool applyNamedChange(U32 &obj, char const *key, double weight, double value);

template <>
bool applyNamedChange(S32 &obj, char const *key, double weight, double value);

template <>
bool applyNamedChange(U64 &obj, char const *key, double weight, double value);

template <>
bool applyNamedChange(S64 &obj, char const *key, double weight, double value);

template <>
bool applyNamedChange(bool &obj, char const *key, double weight, double value);

template <>
bool applyNamedChange(string &obj, char const *key, double weight, double value);

template <>
bool applyNamedChange(std::complex<double> &obj, char const *key, double weight, double value);

#ifdef USE_EIGEN3
/*
  applyNamedChange for Eigen matrices. key must look like "[<index>].<rest>".
  The change is forwarded to the coefficient at linear index <index> (Eigen
  storage order), with <rest> as the remaining key.
  Returns false for a malformed key or an index outside [0, obj.size()).
  Otherwise returns the nested call's result, the same as the vector
  overload; previously this returned true unconditionally.
*/
template <typename _Scalar, int _Rows, int _Cols>
bool applyNamedChange(Eigen::Matrix<_Scalar, _Rows, _Cols> &obj, char const *key, double weight, double value)
{
  if (key[0] != '[') return false;
  char *endp = nullptr;
  long index = strtol(key + 1, &endp, 10);
  // Require at least one digit followed by "]." before the nested key.
  if (endp == key + 1 || endp[0] != ']' || endp[1] != '.') return false;
  // Out-of-range indices would be an out-of-bounds coefficient access.
  if (index < 0 || index >= static_cast<long>(obj.size())) return false;
  return applyNamedChange(obj(index), endp + 2, weight, value);
}
#endif

/*
  applyNamedChange for vectors. key must look like "[<index>].<rest>".
  The change is forwarded to obj[<index>] with <rest> as the remaining key.
  Returns false for a malformed key (including "[]." with no digits) or for
  an index outside [0, obj.size()). Such an index would otherwise be an
  out-of-bounds obj[index] access. The key may come from external input.
*/
template <typename T>
bool applyNamedChange(vector<T> &obj, char const *key, double weight, double value)
{
  if (key[0] != '[') return false;
  char *endp = nullptr;
  long index = strtol(key + 1, &endp, 10);
  if (endp == key + 1 || endp[0] != ']' || endp[1] != '.') return false;
  if (index < 0 || static_cast<size_t>(index) >= obj.size()) return false;
  return applyNamedChange(obj[index], endp + 2, weight, value);
}

/*
  applyNamedChange for string-keyed maps. key must look like "<name>.<rest>".
  The change is forwarded to obj[<name>] with <rest> as the remaining key.
  Keys without a '.' are rejected with false.
*/
template <typename T>
bool applyNamedChange(map<string, T> &obj, char const *key, double weight, double value)
{
  char const *dot = strchr(key, '.');
  if (!dot) return false;
  // operator[] default-constructs the entry when <name> is not yet present.
  return applyNamedChange(obj[string(key, dot)], dot + 1, weight, value);
}

/*
  applyNamedChange through a shared_ptr: forwards to the pointee.
  A null pointer has nothing to modify, so it reports false instead of
  dereferencing null. The hasNaN and randomSample shared_ptr overloads also
  check for null.
*/
template <typename T>
bool applyNamedChange(shared_ptr<T> &obj, char const *key, double weight, double value)
{
  if (!obj) return false;
  return applyNamedChange(*obj, key, weight, value);
}

// Explicit instantiation declaration: translation units including this header do
// not implicitly instantiate std::vector<double>. A matching explicit instantiation
// definition is expected in some translation unit (not visible here).
extern template class std::vector<double>;

namespace packetio {

#ifdef USE_EIGEN3
/*
  Packet serialization for Eigen matrices.
  Wire format written by packet_wr_value: rows (S64), cols (S64), then
  rows * cols * sizeof(Element) raw bytes copied from x.data(), in Eigen's
  storage order.
*/

// Type tag, built once per instantiation. The trailing "@1" looks like a
// format-version suffix — confirm before changing it.
template <typename Element, int Rows, int Cols>
string packet_get_typetag(Eigen::Matrix<Element, Rows, Cols> const &x)
{
  static string typetag(
      "Eigen::Matrix<"s + packet_get_typetag(Element()) + ", "s + to_string(Rows) + ", "s + to_string(Cols)
      + ">@1"s);
  return typetag;
}

template <typename Element, int Rows, int Cols>
void packet_wr_value(packet &p, Eigen::Matrix<Element, Rows, Cols> const &x)
{
  p.add((S64)x.rows());
  p.add((S64)x.cols());
  p.add_bytes(reinterpret_cast<u_char const *>(x.data()), x.size() * sizeof(Element));
}

/*
  Read a matrix written by packet_wr_value. The dimensions come from the
  packet and are validated before x is resized:
  - each must be in [0, 10000000], with at most 1e9 coefficients in total;
  - for fixed-size dimensions, they must equal the compile-time Rows/Cols.
  Throws packet_rd_type_err when a check fails.
*/
template <typename Element, int Rows, int Cols>
void packet_rd_value(packet &p, Eigen::Matrix<Element, Rows, Cols> &x)
{
  S64 rows = 0, cols = 0;
  p.get(rows);
  p.get(cols);
  if (rows < 0 || rows > 10000000 || cols < 0 || cols > 10000000 || rows * cols > 1000000000) {
    throw packet_rd_type_err(
        "implausible array size", "rows="s + to_string(rows) + ", cols="s + to_string(cols));
  }
  // A fixed dimension cannot be resized. Without this check a mismatch trips
  // Eigen's resize assertion in debug builds. In release builds x keeps its
  // fixed size and the wrong number of bytes is read from the packet.
  if ((Rows != Eigen::Dynamic && rows != Rows) || (Cols != Eigen::Dynamic && cols != Cols)) {
    throw packet_rd_type_err(
        "array size mismatch",
        "rows="s + to_string(rows) + ", cols="s + to_string(cols)
        + ", expected rows="s + to_string(Rows) + ", cols="s + to_string(Cols));
  }
  x.resize(rows, cols);
  p.get_bytes(reinterpret_cast<char *>(x.data()), x.size() * sizeof(Element));
}

// Returns the reader for x's type when typetag matches this type's tag exactly, else nullptr.
template <typename Element, int Rows, int Cols>
std::function<void(packet &, Eigen::Matrix<Element, Rows, Cols> &)>
packet_rd_value_compat(Eigen::Matrix<Element, Rows, Cols> const &x, string const &typetag)
{
  if (typetag == packet_get_typetag(x))
    return static_cast<void (*)(packet &, Eigen::Matrix<Element, Rows, Cols> &)>(packetio::packet_rd_value);
  return nullptr;
}
#endif

} // namespace packetio


