#include "./chunky.h"
#include <cstdint>
#include <new>
#if defined(__EMSCRIPTEN__)
#include <malloc.h>
#endif

// Global allocation counters, reported by AllocTracker::printStats (each
// AllocTracker registers itself in its constructor).
//  - trackChunks: a() is called whenever a ChunkyAlloc is constructed
//    (including by move) and f() when one is destroyed.
//  - trackChunkBytes: a(chunkSize) is called for every chunk obtained in
//    ChunkyAlloc::allocFull, and f(totalChunkAlloc) in ~ChunkyAlloc.
AllocTracker trackChunks("chunks");
AllocTracker trackChunkBytes("chunkBytes");

// Creates an empty allocator. No chunk memory is obtained here; the first
// chunk is allocated lazily by allocFull. Every ChunkyAlloc object counts once
// in trackChunks (released again in the destructor).
ChunkyAlloc::ChunkyAlloc()
  : curChunkPtr(nullptr),
    curChunkAvail(0)
{
  // totalChunkAlloc is read by the destructor (tracker release) and by
  // totalAlloc(). Set it explicitly so the accounting never starts from an
  // indeterminate value, whatever the header's member declaration looks like.
  // Assigned in the body rather than the init list to avoid -Wreorder noise.
  totalChunkAlloc = 0;
  trackChunks.a();
}

// Tears the allocator down: reports the object and all of its chunk bytes as
// released to the trackers, then returns every chunk obtained by allocFull
// (via aligned_alloc) to the C heap with free().
ChunkyAlloc::~ChunkyAlloc()
{
  trackChunks.f();
  trackChunkBytes.f(totalChunkAlloc);
  for (auto chunk : chunks) free(chunk);
}

// Move constructor: takes over other's chunks, object list, bump pointer and
// byte total, and leaves other as an empty allocator.
// totalChunkAlloc must travel with the chunks. Otherwise this object's
// totalAlloc() would subtract the inherited curChunkAvail from a total that
// never counted those chunks (size_t wraparound), and other would keep
// reporting bytes it no longer owns.
ChunkyAlloc::ChunkyAlloc(ChunkyAlloc &&other)
  : curChunkPtr(other.curChunkPtr),
    curChunkAvail(other.curChunkAvail),
    chunks(std::move(other.chunks)),
    objects(std::move(other.objects))
{
  trackChunks.a();
  // Assigned in the body rather than the init list to avoid -Wreorder noise.
  totalChunkAlloc = other.totalChunkAlloc;
  other.totalChunkAlloc = 0;
  other.curChunkPtr = nullptr;
  other.curChunkAvail = 0;
  // Moved-from containers are only "valid but unspecified"; make them empty so
  // other's destructor frees nothing we now own.
  other.chunks.clear();
  other.objects.clear();
}

// Move assignment MERGES other into *this rather than replacing it:
//  - chunks already owned by *this are kept, so pointers previously handed
//    out by *this stay valid;
//  - other's chunks, objects and byte total are appended to ours;
//  - allocation continues from other's current chunk, and whatever was left
//    in our previous current chunk is simply no longer allocated from.
// other is left as an empty allocator.
ChunkyAlloc & ChunkyAlloc::operator= (ChunkyAlloc &&other)
{
  // Self-move would append a container to itself while iterating it and then
  // clear it, leaking every chunk. Merging with ourselves is a no-op.
  if (this == &other) return *this;

  // Take ownership of other's chunks together with their byte count, so each
  // chunk is freed (and released from trackChunkBytes) by exactly one owner.
  chunks.insert(chunks.end(), other.chunks.begin(), other.chunks.end());
  other.chunks.clear();
  totalChunkAlloc += other.totalChunkAlloc;
  other.totalChunkAlloc = 0;

  objects.insert(objects.end(), other.objects.begin(), other.objects.end());
  other.objects.clear();

  curChunkPtr = other.curChunkPtr;
  curChunkAvail = other.curChunkAvail;
  other.curChunkPtr = nullptr;
  other.curChunkAvail = 0;
  return *this;
}

// Number of bytes to add to address `addr` to reach the next multiple of
// `align` (assumed to be a power of two); 0 if already aligned.
static size_t chunkyPadding(size_t addr, size_t align)
{
  size_t misalign = addr & (align - 1);
  return misalign ? align - misalign : 0;
}

// Bump-allocates `s` bytes aligned to `align` (assumed a power of two).
// If the current chunk cannot hold the request plus alignment padding, a new
// chunk is obtained with aligned_alloc(64, ...) and becomes the current chunk.
// The unused tail of the old chunk is abandoned, but the chunk stays owned.
// Chunk size: 16384 bytes for the first chunk and 262144 after that, doubled
// while it is smaller than s*4 + align*4 + 16.
// Throws std::bad_alloc if the request is too large to size a chunk for or if
// aligned_alloc fails. Throws logic_error("chunky overrun") if a fresh chunk
// still cannot satisfy the request (a sanity check).
void *ChunkyAlloc::allocFull(size_t align, size_t s)
{
  // Bound the inputs so the sizing arithmetic below cannot overflow. Otherwise
  // chunkSize could double past SIZE_MAX to 0 and the loop would never exit.
  if (s > SIZE_MAX / 32 || align > SIZE_MAX / 32) throw std::bad_alloc();

  // Align the current position first: the request must fit *after* padding.
  size_t pad = chunkyPadding((size_t)curChunkPtr, align);
  if (pad > curChunkAvail || s > curChunkAvail - pad) {
    size_t chunkSize = chunks.empty() ? 16384 : 262144;
    while (s*4 + align*4 + 16 > chunkSize) chunkSize *= 2;
    auto chunk = (U8 *)aligned_alloc(64, chunkSize);
    if (!chunk) throw std::bad_alloc();
    // Don't leak the fresh chunk if recording it fails.
    try {
      chunks.push_back(chunk);
    } catch (...) {
      free(chunk);
      throw;
    }
    trackChunkBytes.a(chunkSize);
    totalChunkAlloc += chunkSize;
    curChunkPtr = chunk;
    curChunkAvail = chunkSize;
    // Chunks are 64-byte aligned, so padding is only nonzero for align > 64.
    pad = chunkyPadding((size_t)curChunkPtr, align);
  }

  if (pad > curChunkAvail || s > curChunkAvail - pad) throw logic_error("chunky overrun");
  curChunkPtr += pad;
  curChunkAvail -= pad;

  auto ret = curChunkPtr;
  curChunkPtr += s;
  curChunkAvail -= s;
  return (void *)ret;
}


// Bytes of chunk memory consumed so far: the size of every chunk this
// allocator accounts for (totalChunkAlloc), minus what is still unused in the
// current chunk. Alignment padding and abandoned chunk tails count as consumed.
size_t ChunkyAlloc::totalAlloc() const
{
  const auto stillFree = curChunkAvail;
  return totalChunkAlloc - stillFree;
}

// Registry of AllocTrackers, walked by AllocTracker::printStats.
// It is a plain pointer with static storage, so it is zero-initialized before
// any dynamic initialization. AllocTracker's constructor creates the vector
// on first use, presumably so trackers that are themselves globals can
// register regardless of static-initialization order.
// Destroyed trackers null out their slot instead of erasing it, and the vector
// is never deleted in this file (it is file-static, so nothing else can).
static vector<AllocTracker *> *allTrackers;

// Records the tracker's display name and registers this tracker in the
// file-static registry, creating the registry on first use, so that
// printStats will report it.
AllocTracker::AllocTracker(char const *_name)
  : name(_name)
{
  if (allTrackers == nullptr) {
    allTrackers = new vector<AllocTracker *>;
  }
  allTrackers->push_back(this);
}

// Unregisters this tracker by nulling every registry slot that points at it.
// Entries are not erased; printStats skips null slots. The registry itself is
// left alive.
AllocTracker::~AllocTracker()
{
  if (!allTrackers) return;
  for (size_t i = 0; i < allTrackers->size(); i++) {
    auto &slot = (*allTrackers)[i];
    if (slot == this) slot = nullptr;
  }
}

// Returns the report that printStats would write, as a string.
string AllocTracker::fmtStats()
{
  ostringstream report;
  printStats(report);
  string text = report.str();
  return text;
}

// Writes an allocation report to `s`. Under Emscripten it starts with the
// current program break (sbrk(0)) and a mallinfo() summary; otherwise it
// starts with a bare "Alloc:" header. It then prints one
// "  name : nAlloc / nTotalAlloc" line per registered, still-live tracker.
void AllocTracker::printStats(ostream &s)
{
#if defined(__EMSCRIPTEN__)
  s << "Alloc: total " << sbrk(0) << "\n";
  auto mi = mallinfo();
  s << "  arena=" << mi.arena << " ordblks=" << mi.ordblks <<
    " hblks=" << mi.hblks << " hblkhd=" << mi.hblkhd <<
    " uordblks=" << mi.uordblks << " fordblks=" << mi.fordblks << "\n";
#else
  s << "Alloc:\n";
#endif

  if (!allTrackers) return;
  for (AllocTracker *tracker : *allTrackers) {
    if (!tracker) continue;  // slot left behind by a destroyed tracker
    s << "  " << tracker->name << " : " << tracker->nAlloc
      << " / " << tracker->nTotalAlloc << "\n";
  }
}

// Prints a "negative alloc" diagnostic to stderr, showing `newTotal` and this
// tracker's nTotalAlloc. It only reports; no counters are changed here. The
// callers are not visible in this file; presumably this runs when a release
// drives a tracker's total below zero.
void AllocTracker::barf(int64_t newTotal)
{
  auto &err = cerr;
  err << "AllocTracker: negative alloc at " << newTotal;
  err << " / " << nTotalAlloc << "\n";
}
