#include "./haltonseq.h"

// The 45 odd primes up to 200, one radix per halton axis. It's hard to imagine wanting
// more than 45 axes.
// The bases must be pairwise coprime. A composite such as 27 (= 3^3) next to 3 would
// make those two axes badly non-uniform: i mod 27 fixes the leading 1/27 cell of both,
// so only 27 of the 27x27 grid cells would ever be visited.
static const vector<u_int> haltonAxes{3,   5,   7,   11,  13,  17,  19,  23,  29,  31,  37,  41,  43,  47,  53,  59,
                                      61,  67,  71,  73,  79,  83,  89,  97,  101, 103, 107, 109, 113, 127, 131, 137,
                                      139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199};

/*
  Return the (i)th number of the halton (van der Corput) sequence in base (radix).
  The base-(radix) digits of i are mirrored about the radix point, so the result lies
  in [0..1) and approximates a uniform distribution.
  radix must be >= 2; with radix == 1 the digit scan never terminates.
*/
double unipolarHaltonAxis(u_int i, u_int radix)
{
  if (i == 0) return 0.0;
  const double placeValue = 1.0 / radix;

  // Weight of the most significant base-(radix) digit of i. The loop only multiplies
  // when weight * radix <= i, so it cannot overflow.
  u_int weight = 1;
  while (i / weight >= radix) weight *= radix;

  // Fold digits from most to least significant. Each step divides the accumulated
  // fraction by radix, so the least significant digit of i ends up just after the
  // radix point.
  double acc = 0.0;
  for (; weight != 0; weight /= radix) {
    const double digitValue = int((i / weight) % radix);
    acc = (digitValue + acc) * placeValue;
  }
  return acc;
}

/*
  Return the (i)th point of the (nCols)-dimensional halton sequence.
  Column c is unipolarHaltonAxis(i, haltonAxes[c]), a value in [0..1).
  nCols may not exceed the number of tabulated axes.
*/
vector<R> unipolarHaltonRow(u_int i, size_t nCols)
{
  assert(nCols <= haltonAxes.size());
  vector<R> row;
  row.reserve(nCols);
  for (size_t col = 0; col != nCols; ++col) {
    row.push_back(unipolarHaltonAxis(i, haltonAxes[col]));
  }
  return row;
}

/*
  Return the (i)th number of the bipolar halton sequence in base (radix).
  Each base-(radix) digit d is folded onto a signed value symmetric about zero,
    0, 1, 2, 3, 4, ...  ->  0, -2, +2, -4, +4, ...
  and the folded digits are mirrored about the radix point as in unipolarHaltonAxis.
  For odd radix the result lies in [-1..1] and approximates a uniform distribution there.
  radix must be >= 2; with radix == 1 the digit scan never terminates.
*/
double bipolarHaltonAxis(u_int i, u_int radix)
{
  if (i == 0) return 0.0;
  const double placeValue = 1.0 / radix;

  // Weight of the most significant base-(radix) digit of i (no overflow: only
  // multiplied while weight * radix <= i).
  u_int weight = 1;
  while (i / weight >= radix) weight *= radix;

  double acc = 0.0;
  for (; weight != 0; weight /= radix) {
    const int digit = int((i / weight) % radix);
    const int half = (digit + 1) / 2;
    const int folded = (digit % 2 != 0) ? -half : half;  // odd digits go negative
    acc = (2.0 * folded + acc) * placeValue;
  }
  return acc;
}

/*
  Return the (i)th point of the (nCols)-dimensional bipolar halton sequence.
  Column c is bipolarHaltonAxis(i, haltonAxes[c]), a value in [-1..1].
  nCols may not exceed the number of tabulated axes.
*/
vector<R> bipolarHaltonRow(u_int i, size_t nCols)
{
  assert(nCols <= haltonAxes.size());
  vector<R> row;
  row.reserve(nCols);
  for (size_t col = 0; col != nCols; ++col) {
    row.push_back(bipolarHaltonAxis(i, haltonAxes[col]));
  }
  return row;
}

/*
  Box-Muller transform: map two independent uniform samples to two independent normal
  samples with mean 0 and variance 1. They are returned as the real and imaginary
  parts of one complex number.
  u1 sets the magnitude, sqrt(-2 ln u1), and must lie in (0..1]; u1 == 0 sends
  log(u1) to -infinity and the magnitude to infinity.
  u2 sets the angle, 2*pi*u2.
  See http://en.wikipedia.org/wiki/Box-Muller_transform
*/
static std::complex<double> boxMullerTransform(double u1, double u2)
{
  const double angle = 2.0 * M_PI * u2;
  const double radius = sqrt(-2.0 * log(u1));
  const double re = cos(angle) * radius;
  const double im = sin(angle) * radius;
  return {re, im};
}

/*
  Return an (nCols)-tuple of approximately standard-normal values (mean 0, variance 1)
  built from halton point (i + 1).
  Columns are produced in pairs. Axes (c, c + 1) go through the Box-Muller transform:
  column c gets the real part and column c + 1 the imaginary part. For odd nCols the
  last imaginary part is dropped, but its axis is still read. That is why the bound is
  nCols + 1, which is one stricter than necessary for even nCols.
  The index is shifted by one because halton point 0 is all zeros, and u1 == 0 would
  make the Box-Muller magnitude infinite.
*/
vector<R> gaussianHaltonRow(u_int i, size_t nCols)
{
  assert(nCols + 1 <= haltonAxes.size());
  const u_int point = i + 1;  // NOTE(review): wraps to 0 if i is the maximum u_int
  const size_t pairCount = (nCols + 1) / 2;

  vector<R> out(nCols);
  for (size_t p = 0; p < pairCount; ++p) {
    const size_t col = 2 * p;
    const double radial = unipolarHaltonAxis(point, haltonAxes[col]);
    const double angular = unipolarHaltonAxis(point, haltonAxes[col + 1]);
    const std::complex<double> normals = boxMullerTransform(radial, angular);

    out[col] = normals.real();
    const bool hasPartner = col + 1 < nCols;
    if (hasPartner) out[col + 1] = normals.imag();
  }
  return out;
}

// The R2 low-discrepancy (quasirandom) sequence.
// http://extremelearning.com.au/unreasonable-effectiveness-of-quasirandom-sequences/

/*
  Return the positive root g of g^(d+1) = g + 1, which is the golden ratio for d == 1
  and the plastic number for d == 2. It is found with 20 Newton steps starting at 1.
  For d == 0 the equation has no root; the first Newton step divides by zero and the
  result is not finite.
*/
double r2Gamma(u_int d)
{
  double root = 1.0;
  int steps = 20;
  while (steps-- > 0) {
    const double residual = pow(root, d + 1) - root - 1.0;
    const double slope = double(d + 1) * pow(root, d) - 1.0;
    root -= residual / slope;
  }
  return root;
}

/*
  Return the d per-axis step sizes of the R2 sequence:
    alpha_c = (1 / g)^(c + 1),  c = 0 .. d-1,  with g = r2Gamma(d).
  Since g > 1, every power is already inside (0..1); the fmod only guards that range.
*/
vector<R> r2Factors(u_int d)
{
  const double invGamma = 1.0 / r2Gamma(d);
  vector<R> alpha;
  alpha.reserve(d);
  for (u_int axis = 0; axis < d; ++axis) {
    alpha.push_back(fmod(pow(invGamma, axis + 1), 1.0));
  }
  return alpha;
}

/*
  Return the (i)th point of the d-dimensional R2 sequence:
    x_c = frac(0.5 + (i + 1) * alpha_c),  alpha = r2Factors(d).
  Counting starts at i + 1, so i == 0 yields the first point after the 0.5 seed.
  frac is presumably the fractional part (values in [0..1)); confirm against its
  definition.
  NOTE(review): r2Factors re-runs the Newton solve on every call; callers generating
  many rows may want to cache it.
*/
vector<R> unipolarR2Row(u_int i, u_int d)
{
  const vector<R> alpha = r2Factors(d);
  const R step = R(i + 1);
  vector<R> point;
  point.reserve(alpha.size());
  for (const R &a : alpha) {
    point.push_back(frac(0.5 + step * a));
  }
  return point;
}

/*
  Return the (i)th point of the d-dimensional R2 sequence, with each coordinate mapped
  affinely by x -> 2x - 1 (i.e. [0..1) onto [-1..1), assuming frac yields [0..1)).
*/
vector<R> bipolarR2Row(u_int i, u_int d)
{
  vector<R> point = unipolarR2Row(i, d);
  for (R &coord : point) {
    coord = 2.0 * coord - 1.0;
  }
  return point;
}

/*
  Return a d-tuple of approximately standard-normal values (mean 0, variance 1) built
  from the (i)th R2 point.
  The R2 point is generated with d rounded up to an even dimension. Its coordinates are
  then consumed in pairs through the Box-Muller transform: column c gets the real part
  and column c + 1 the imaginary part. For odd d the last imaginary part is dropped.
  NOTE(review): unlike gaussianHaltonRow, nothing here keeps the first coordinate of a
  pair away from 0. If frac can return exactly 0, the corresponding values become
  infinite.
*/
vector<R> gaussianR2Row(u_int i, size_t d)
{
  const size_t evenDims = (d + 1) / 2 * 2;
  const vector<R> uniform = unipolarR2Row(i, evenDims);

  vector<R> out(d);
  size_t col = 0;
  while (col < d) {
    const double radial = uniform[col];
    const double angular = uniform[col + 1];
    const std::complex<double> normals = boxMullerTransform(radial, angular);
    out[col] = normals.real();
    if (col + 1 < d) out[col + 1] = normals.imag();
    col += 2;
  }
  return out;
}
