#pragma once
/**
 * Simple serializer / deserializer inspired by protobuf
 *
 * Structure:
 * - 4 bytes magic number   (GLUE_MAGIC)
 * - 4 bytes version number (GLUE_VERSION)
 * - 8 bytes message prototype ID
 * - 4 bytes message length, unsigned number
 *   (NOTE: glue_handler::serialize / deserialize in this file neither write
 *   nor read this length; the fields follow the prototype ID directly)
 * - message data
 *
 * Each field in the message is encoded as:
 * - 4 bytes data type
 * - 4 bytes size, unsigned number (only for string, raw and array)
 * - data
 */

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// reserved for future, do not edit the version number for now
#define GLUE_VERSION 1

#define GLUE_MAGIC 0x45554c47 // "GLUE"
#define GLUE_PROTO_ID_LEN 8

#ifndef GLUE_DEBUG
#define GLUE_DEBUG(...)
#endif

#define BITS_TO_BYTES(x) ((x) / 8)

// Data types
// Note: we're doing polymorphism using enum to prevent using virtual functions

// Type tag written as a 32-bit value in front of every field.
// The numeric values travel on the wire (they are emitted with append_u32()
// and read back with read_u32()), so they are spelled out explicitly here.
enum glue_dtype
{
  GLUE_DTYPE_NULL = 0,
  GLUE_DTYPE_BOOL = 1,
  GLUE_DTYPE_INT = 2,
  GLUE_DTYPE_FLOAT = 3,
  GLUE_DTYPE_STRING = 4,
  GLUE_DTYPE_RAW = 5,
  GLUE_DTYPE_ARRAY_BOOL = 6,
  GLUE_DTYPE_ARRAY_INT = 7,
  GLUE_DTYPE_ARRAY_FLOAT = 8,
  GLUE_DTYPE_ARRAY_STRING = 9,
  GLUE_DTYPE_ARRAY_RAW = 10,
};

// Read-only byte pointer into a serialized message; used by glue_inbuf as
// both the buffer start and the read cursor.
using glue_data_ptr = const char *;

// Growable byte buffer that messages are serialized into.
// Fixed-width values are appended as their in-memory bytes; no byte-order
// conversion is performed.
struct glue_outbuf
{
  std::vector<char> data;

  // Start with 1 KiB of capacity.
  glue_outbuf() { data.reserve(1024); }

  // Append `size` bytes starting at `val`.
  void append(const char *val, size_t size)
  {
    GLUE_DEBUG(" << offset = 0x%02zx\n", data.size());
    const char *const last = val + size;
    data.insert(data.end(), val, last);
  }

  // Append the characters of `val` (no length prefix, no terminator).
  void append_str(const std::string &val)
  {
    append(val.data(), val.size());
  }

  // Append a 32-bit unsigned integer.
  void append_u32(uint32_t val)
  {
    append(reinterpret_cast<const char *>(&val), BITS_TO_BYTES(32));
  }

  // Append a 32-bit signed integer.
  void append_i32(int32_t val)
  {
    append(reinterpret_cast<const char *>(&val), BITS_TO_BYTES(32));
  }

  // Append a 32-bit float.
  void append_f32(float val)
  {
    append(reinterpret_cast<const char *>(&val), BITS_TO_BYTES(32));
  }

  // Drop the content and make sure at least 1 KiB of capacity is available.
  void clear()
  {
    data.clear();
    data.reserve(1024);
  }
};

struct glue_inbuf
{
  glue_data_ptr base;
  glue_data_ptr cur;
  glue_inbuf(glue_data_ptr data) : base(data), cur(data) {}
  uint32_t read_u32()
  {
    GLUE_DEBUG(" >> offset = 0x%02zx\n", cur - base);
    uint32_t val = *(uint32_t *)cur;
    cur += BITS_TO_BYTES(32);
    return val;
  }
  int32_t read_i32()
  {
    GLUE_DEBUG(" >> offset = 0x%02zx\n", cur - base);
    int32_t val = *(int32_t *)cur;
    cur += BITS_TO_BYTES(32);
    return val;
  }
  float read_f32()
  {
    GLUE_DEBUG(" >> offset = 0x%02zx\n", cur - base);
    float val = *(float *)cur;
    cur += BITS_TO_BYTES(32);
    return val;
  }
  std::string read_str(uint32_t size)
  {
    GLUE_DEBUG(" >> offset = 0x%02zx\n", cur - base);
    std::string val(cur, size);
    cur += size;
    return val;
  }
  std::vector<char> read_raw(uint32_t size)
  {
    GLUE_DEBUG(" >> offset = 0x%02zx\n", cur - base);
    std::vector<char> val;
    val.reserve(size);
    val.insert(val.end(), cur, cur + size);
    cur += size;
    return val;
  }

  // for array
  void read(uint32_t &out) { out = read_u32(); }
  void read(int32_t &out) { out = read_i32(); }
  void read(float &out) { out = read_f32(); }
  void read(std::string &out)
  {
    uint32_t size = read_u32();
    out = read_str(size);
  }
  void read(std::vector<char> &out)
  {
    uint32_t size = read_u32();
    out = read_raw(size);
  }
  void read(std::vector<uint8_t> &out)
  {
    uint32_t size = read_u32();
    auto tmp = read_raw(size);
    out.assign((uint8_t*)tmp.data(), (uint8_t*)tmp.data() + tmp.size());
  }
};

struct glue_type_base;

// Per-message registry: the 8-character prototype ID plus the list of fields
// in registration order. The field pointers are not owned.
struct glue_handler
{
  const char *name = nullptr;
  std::vector<glue_type_base *> fields;

  glue_handler(const char *proto_id) : name{proto_id} {}

  // Record `field` at the end of the field list.
  void register_field(glue_type_base *field) { fields.push_back(field); }

  // Clear `output`, then write the header followed by every registered field.
  void serialize(glue_outbuf &output);

  // Validate the header in `input`, then parse every registered field.
  void deserialize(glue_inbuf &input);
};

// Common base of every field type.
// `dtype` doubles as the runtime type tag (used by glue_handler to pick the
// concrete type) and as the null marker (GLUE_DTYPE_NULL).
struct glue_type_base
{
  const char *name = nullptr;
  glue_dtype dtype = GLUE_DTYPE_NULL;
  // Copy of the handler taken at construction time, i.e. before this field is
  // registered. NOTE(review): nothing in this file reads it — confirm whether
  // it is still needed.
  glue_handler handler;

  glue_type_base() = delete;

  // Registers this field with the caller's `handler`.
  // The initializers are listed in declaration order (name, dtype, handler),
  // which is the order the members are actually initialized in.
  glue_type_base(const char *name, glue_handler &handler, glue_dtype dtype)
      : name(name), dtype(dtype), handler(handler)
  {
    // Inside the body `handler` names the parameter, so the registration goes
    // to the caller's handler and not to the member copy.
    handler.register_field(this);
  }

  bool is_null() const { return dtype == GLUE_DTYPE_NULL; }
  bool not_null() const { return !is_null(); }
  void set_null() { dtype = GLUE_DTYPE_NULL; }

  // Read the 32-bit type tag from `input` and store it in `dtype`.
  // Returns true when the tag is GLUE_DTYPE_NULL, meaning no payload follows.
  bool parse_type(glue_inbuf &input)
  {
    dtype = (glue_dtype)input.read_u32();
    if (dtype == GLUE_DTYPE_NULL)
    {
      GLUE_DEBUG(" >> null\n");
      return true;
    }
    return false;
  }
};

// Boolean field; the value occupies a full 32-bit unsigned integer on the wire.
struct glue_bool : glue_type_base
{
  bool value = false;

  glue_bool(const char *name, glue_handler &handler)
      : glue_type_base(name, handler, GLUE_DTYPE_BOOL) {}

  // Read the type tag and, unless it is NULL, the 32-bit payload.
  void parse(glue_inbuf &input)
  {
    const bool tagged_null = parse_type(input);
    if (!tagged_null)
    {
      value = input.read_u32() != 0;
      GLUE_DEBUG(" >> bool %d\n", value);
    }
  }

  // Write the current type tag followed by the value as 0 or 1.
  void serialize(glue_outbuf &output)
  {
    GLUE_DEBUG(" << bool %d\n", value);
    output.append_u32(static_cast<uint32_t>(dtype));
    output.append_u32(value ? 1u : 0u);
  }
};

// 32-bit signed integer field.
struct glue_int : glue_type_base
{
  int32_t value = 0;

  glue_int(const char *name, glue_handler &handler)
      : glue_type_base(name, handler, GLUE_DTYPE_INT) {}

  // Read the type tag and, unless it is NULL, the 32-bit payload.
  void parse(glue_inbuf &input)
  {
    const bool tagged_null = parse_type(input);
    if (!tagged_null)
    {
      value = input.read_i32();
      GLUE_DEBUG(" >> int %d\n", value);
    }
  }

  // Write the current type tag followed by the value.
  void serialize(glue_outbuf &output)
  {
    GLUE_DEBUG(" << int %d\n", value);
    output.append_u32(static_cast<uint32_t>(dtype));
    output.append_i32(value);
  }
};

// 32-bit floating point field.
struct glue_float : glue_type_base
{
  float value = 0.0f;

  glue_float(const char *name, glue_handler &handler)
      : glue_type_base(name, handler, GLUE_DTYPE_FLOAT) {}

  // Read the type tag and, unless it is NULL, the 32-bit payload.
  void parse(glue_inbuf &input)
  {
    const bool tagged_null = parse_type(input);
    if (!tagged_null)
    {
      value = input.read_f32();
      GLUE_DEBUG(" >> float %f\n", value);
    }
  }

  // Write the current type tag followed by the value.
  void serialize(glue_outbuf &output)
  {
    GLUE_DEBUG(" << float %f\n", value);
    output.append_u32(static_cast<uint32_t>(dtype));
    output.append_f32(value);
  }
};

// String field: a 32-bit byte count followed by that many characters.
struct glue_str : glue_type_base
{
  std::string value;

  glue_str(const char *name, glue_handler &handler)
      : glue_type_base(name, handler, GLUE_DTYPE_STRING) {}

  // Read the type tag and, unless it is NULL, the length-prefixed characters.
  void parse(glue_inbuf &input)
  {
    const bool tagged_null = parse_type(input);
    if (!tagged_null)
    {
      const uint32_t n_bytes = input.read_u32();
      value = input.read_str(n_bytes);
      GLUE_DEBUG(" >> string %s\n", value.c_str());
    }
  }

  // Write the current type tag, the byte count, then the characters.
  void serialize(glue_outbuf &output)
  {
    GLUE_DEBUG(" << string %s\n", value.c_str());
    output.append_u32(static_cast<uint32_t>(dtype));
    output.append_u32(static_cast<uint32_t>(value.size()));
    output.append_str(value);
  }
};

// Raw byte buffer field: a 32-bit byte count followed by that many bytes.
struct glue_raw : glue_type_base
{
  std::vector<char> buf;

  glue_raw(const char *name, glue_handler &handler)
      : glue_type_base(name, handler, GLUE_DTYPE_RAW) {}

  // Read the type tag and, unless it is NULL, the length-prefixed bytes.
  void parse(glue_inbuf &input)
  {
    const bool tagged_null = parse_type(input);
    if (!tagged_null)
    {
      const uint32_t n_bytes = input.read_u32();
      buf = input.read_raw(n_bytes);
      GLUE_DEBUG(" >> raw, size = %zu\n", buf.size());
    }
  }

  // Write the current type tag, the byte count, then the bytes.
  void serialize(glue_outbuf &output)
  {
    GLUE_DEBUG(" << raw, size = %zu\n", buf.size());
    output.append_u32(static_cast<uint32_t>(dtype));
    output.append_u32(static_cast<uint32_t>(buf.size()));
    output.append(buf.data(), buf.size());
  }
};

// Array field: a 32-bit element count followed by the elements.
// Elements are read through the glue_inbuf::read(T &) overloads and written
// through `serialize_elem`, which the concrete glue_arr_* types install in
// their constructors.
template <typename T>
struct glue_arr : glue_type_base
{
  std::vector<T> arr;
  std::function<void(T &, glue_outbuf &)> serialize_elem;

  glue_arr(const char *name, glue_handler &handler, glue_dtype dtype) : glue_type_base(name, handler, dtype) {}

  // Read the type tag and, unless it is NULL, replace `arr` with the parsed
  // elements.
  void parse(glue_inbuf &input)
  {
    if (parse_type(input))
      return;
    const uint32_t size = input.read_u32();
    GLUE_DEBUG(" >> array[%u]\n", size);
    // Start from an empty array so that parsing into an already-populated
    // field replaces its content (as glue_str / glue_raw do) instead of
    // appending to it.
    arr.clear();
    arr.reserve(size);
    for (uint32_t i = 0; i < size; i++)
    {
      T elem;
      input.read(elem);
      arr.push_back(std::move(elem));
    }
  }

  // Write the current type tag, the element count, then every element.
  void serialize(glue_outbuf &output)
  {
    GLUE_DEBUG(" << array[%zu]\n", arr.size());
    output.append_u32(dtype);
    output.append_u32(arr.size());
    // Iterate by reference: copying each string / byte vector is unnecessary.
    for (auto &elem : arr)
    {
      serialize_elem(elem, output);
    }
  }
};

// Defines `struct glue_arr_<tname>`, a glue_arr<dtype> tagged with `enum_type`.
// `serialize_fn` is a brace-enclosed lambda body that writes one element; it
// sees the element as `elem` and the output buffer as `output`.
#define DEF_GLUE_ARR(tname, dtype, enum_type, serialize_fn)                                                   \
  struct glue_arr_##tname : glue_arr<dtype>                                                                   \
  {                                                                                                           \
    glue_arr_##tname(const char *name, glue_handler &handler) : glue_arr<dtype>(name, handler, enum_type) \
    {                                                                                                         \
      serialize_elem = [](dtype & elem, glue_outbuf & output) serialize_fn;                                   \
    }                                                                                                         \
  };

// glue_arr_bool: elements are stored as uint32_t and written as 32-bit values.
DEF_GLUE_ARR(bool, uint32_t, GLUE_DTYPE_ARRAY_BOOL, {
  output.append_u32(elem);
})
// glue_arr_int: 32-bit signed integers.
DEF_GLUE_ARR(int, int32_t, GLUE_DTYPE_ARRAY_INT, {
  output.append_i32(elem);
})
// glue_arr_float: 32-bit floats.
DEF_GLUE_ARR(float, float, GLUE_DTYPE_ARRAY_FLOAT, {
  output.append_f32(elem);
})
// glue_arr_str: each element is a 32-bit byte count followed by the characters.
DEF_GLUE_ARR(str, std::string, GLUE_DTYPE_ARRAY_STRING, {
  output.append_u32(elem.size());
  output.append_str(elem);
})
// glue_arr_raw: each element is a 32-bit byte count followed by the bytes.
DEF_GLUE_ARR(raw, std::vector<uint8_t>, GLUE_DTYPE_ARRAY_RAW, {
  output.append_u32(elem.size());
  output.append((const char*)elem.data(), elem.size());
})

// Message base

// Serialize the message into `output`.
//
// `output` is cleared first. The layout written is: GLUE_MAGIC, GLUE_VERSION,
// the first GLUE_PROTO_ID_LEN bytes of `name`, then every registered field in
// registration order. A field whose dtype is GLUE_DTYPE_NULL is written as a
// bare NULL tag with no payload.
//
// Declared `inline` because this definition lives in a header; without it,
// including the header from more than one translation unit would produce
// duplicate-symbol link errors.
inline void glue_handler::serialize(glue_outbuf &output)
{
  output.clear();
  output.append_u32(GLUE_MAGIC);
  output.append_u32(GLUE_VERSION);
  output.append(name, GLUE_PROTO_ID_LEN);
  GLUE_DEBUG("Serializing message %s\n", name);
  GLUE_DEBUG("Fields: %zu\n", fields.size());
  for (glue_type_base *field : fields)
  {
    GLUE_DEBUG("Serializing field %s, type = %d\n", field->name, field->dtype);
    // The dtype tag identifies the concrete type of the field, so the
    // downcasts below stand in for virtual dispatch.
    switch (field->dtype)
    {
    case GLUE_DTYPE_NULL:
      output.append_u32(GLUE_DTYPE_NULL);
      break;
    case GLUE_DTYPE_BOOL:
      static_cast<glue_bool *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_INT:
      static_cast<glue_int *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_FLOAT:
      static_cast<glue_float *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_STRING:
      static_cast<glue_str *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_RAW:
      static_cast<glue_raw *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_ARRAY_BOOL:
      static_cast<glue_arr_bool *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_ARRAY_INT:
      static_cast<glue_arr_int *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_ARRAY_FLOAT:
      static_cast<glue_arr_float *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_ARRAY_STRING:
      static_cast<glue_arr_str *>(field)->serialize(output);
      break;
    case GLUE_DTYPE_ARRAY_RAW:
      static_cast<glue_arr_raw *>(field)->serialize(output);
      break;
    }
  }
}

// Parse a message from `input` into the registered fields.
//
// The header (magic, version, prototype ID) is validated first; then every
// registered field is parsed in registration order, dispatching on the dtype
// the field holds *before* parsing.
//
// A field whose dtype is currently GLUE_DTYPE_NULL only has its type tag
// consumed; if that tag turns out to be non-null, its payload is not read.
//
// Throws std::runtime_error when the magic number, the version or the
// prototype ID does not match.
//
// Declared `inline` because this definition lives in a header; without it,
// including the header from more than one translation unit would produce
// duplicate-symbol link errors.
inline void glue_handler::deserialize(glue_inbuf &input)
{
  const uint32_t magic = input.read_u32();
  if (magic != GLUE_MAGIC)
  {
    throw std::runtime_error("Invalid magic number");
  }

  const uint32_t version = input.read_u32();
  if (version != GLUE_VERSION)
  {
    throw std::runtime_error("Version mismatch");
  }

  const std::string proto_id = input.read_str(GLUE_PROTO_ID_LEN);
  if (proto_id != name)
  {
    throw std::runtime_error("Prototype ID mismatch " + proto_id + " != " + name);
  }

  GLUE_DEBUG("Deserializing message %s\n", name);
  for (glue_type_base *field : fields)
  {
    GLUE_DEBUG("Deserializing field %s, type = %d\n", field->name, field->dtype);
    switch (field->dtype)
    {
    case GLUE_DTYPE_NULL:
      field->parse_type(input);
      break;
    case GLUE_DTYPE_BOOL:
      static_cast<glue_bool *>(field)->parse(input);
      break;
    case GLUE_DTYPE_INT:
      static_cast<glue_int *>(field)->parse(input);
      break;
    case GLUE_DTYPE_FLOAT:
      static_cast<glue_float *>(field)->parse(input);
      break;
    case GLUE_DTYPE_STRING:
      static_cast<glue_str *>(field)->parse(input);
      break;
    case GLUE_DTYPE_RAW:
      static_cast<glue_raw *>(field)->parse(input);
      // This break was missing: a raw field used to fall through and be
      // parsed a second time as if it were a glue_arr_bool.
      break;
    case GLUE_DTYPE_ARRAY_BOOL:
      static_cast<glue_arr_bool *>(field)->parse(input);
      break;
    case GLUE_DTYPE_ARRAY_INT:
      static_cast<glue_arr_int *>(field)->parse(input);
      break;
    case GLUE_DTYPE_ARRAY_FLOAT:
      static_cast<glue_arr_float *>(field)->parse(input);
      break;
    case GLUE_DTYPE_ARRAY_STRING:
      static_cast<glue_arr_str *>(field)->parse(input);
      break;
    case GLUE_DTYPE_ARRAY_RAW:
      static_cast<glue_arr_raw *>(field)->parse(input);
      break;
    }
  }
}

// Compile-time check that a prototype ID literal holds exactly
// GLUE_PROTO_ID_LEN characters plus the terminating NUL.
// Returns the literal unchanged.
template <std::size_t N>
constexpr auto PROTO_ID(const char (&s)[N]) -> const char (&)[N]
{
  static_assert(N == GLUE_PROTO_ID_LEN + 1, "Prototype ID must be 8 characters long");
  return s;
}
// Helpers for declaring messages inside a struct body.
// GLUE_HANDLER must come first: it declares the member `handler` that the
// field macros pass to each field's constructor.
// GLUE_FIELD declares a member `glue_<type> <name>` registered with `handler`.
// GLUE_FIELD_NULLABLE currently expands to exactly the same code as GLUE_FIELD.
#define GLUE_FIELD(type, name) glue_##type name = glue_##type(#name, handler);
#define GLUE_FIELD_NULLABLE(type, name) glue_##type name = glue_##type(#name, handler);
#define GLUE_HANDLER(name) glue_handler handler = glue_handler(PROTO_ID(name));

// Message for events

// Message with prototype ID "erro_evt"; carries a single string field.
struct glue_msg_error
{
  GLUE_HANDLER("erro_evt")
  GLUE_FIELD(str, message)
};

// Message for actions

// Message with prototype ID "load_req".
// Fields register with `handler` in declaration order, and that order is the
// order in which they are written to / read from the wire — do not reorder.
struct glue_msg_load_req
{
  GLUE_HANDLER("load_req")
  GLUE_FIELD(arr_str, model_paths)
  GLUE_FIELD_NULLABLE(str, mmproj_path)
  GLUE_FIELD(bool, n_ctx_auto)
  GLUE_FIELD(bool, use_mmap)
  GLUE_FIELD(bool, use_mlock)
  GLUE_FIELD(int, n_gpu_layers)
  GLUE_FIELD(int, n_ctx)
  GLUE_FIELD(int, n_threads)
  GLUE_FIELD_NULLABLE(str, model_alias)
  GLUE_FIELD_NULLABLE(int, log_level)
  GLUE_FIELD_NULLABLE(bool, embeddings)
  GLUE_FIELD_NULLABLE(bool, offload_kqv)
  GLUE_FIELD_NULLABLE(int, n_batch)
  GLUE_FIELD_NULLABLE(int, n_ubatch)
  GLUE_FIELD_NULLABLE(int, n_parallel)
  GLUE_FIELD_NULLABLE(str, pooling_type)
  GLUE_FIELD_NULLABLE(str, rope_scaling_type)
  GLUE_FIELD_NULLABLE(float, rope_freq_base)
  GLUE_FIELD_NULLABLE(float, rope_freq_scale)
  GLUE_FIELD_NULLABLE(float, yarn_ext_factor)
  GLUE_FIELD_NULLABLE(float, yarn_attn_factor)
  GLUE_FIELD_NULLABLE(float, yarn_beta_fast)
  GLUE_FIELD_NULLABLE(float, yarn_beta_slow)
  GLUE_FIELD_NULLABLE(int, yarn_orig_ctx)
  GLUE_FIELD_NULLABLE(str, cache_type_k)
  GLUE_FIELD_NULLABLE(str, cache_type_v)
  GLUE_FIELD_NULLABLE(bool, kv_unified)
  GLUE_FIELD_NULLABLE(bool, flash_attn)
  GLUE_FIELD_NULLABLE(bool, swa_full)
  GLUE_FIELD_NULLABLE(int, n_ctx_checkpoints)
  GLUE_FIELD_NULLABLE(int, checkpoint_min_step)
  GLUE_FIELD_NULLABLE(str, chat_template)
  GLUE_FIELD_NULLABLE(bool, jinja)
  GLUE_FIELD_NULLABLE(arr_str, default_template_kwargs_keys)
  GLUE_FIELD_NULLABLE(arr_str, default_template_kwargs_vals)
  GLUE_FIELD_NULLABLE(bool, reasoning)
  GLUE_FIELD_NULLABLE(int, image_min_tokens)
  GLUE_FIELD_NULLABLE(int, image_max_tokens)
  GLUE_FIELD_NULLABLE(bool, warmup)
  GLUE_FIELD_NULLABLE(bool, no_kv_offload)
  GLUE_FIELD_NULLABLE(bool, mmproj_offload)
  GLUE_FIELD_NULLABLE(bool, cont_batching)
  GLUE_FIELD_NULLABLE(int, n_keep)
  GLUE_FIELD_NULLABLE(bool, ctx_shift)
  GLUE_FIELD_NULLABLE(bool, cache_idle_slots)
  GLUE_FIELD_NULLABLE(int, n_cache_reuse)
  GLUE_FIELD_NULLABLE(arr_str, lora_paths)
  GLUE_FIELD_NULLABLE(arr_float, lora_scales)
  GLUE_FIELD_NULLABLE(bool, lora_init_without_apply)
  GLUE_FIELD_NULLABLE(str, spec_draft_model)
  GLUE_FIELD_NULLABLE(int, spec_draft_ngl)
  GLUE_FIELD_NULLABLE(int, spec_draft_n_max)
  GLUE_FIELD_NULLABLE(int, spec_draft_n_min)
  GLUE_FIELD_NULLABLE(float, spec_draft_p_min)
  GLUE_FIELD_NULLABLE(int, spec_draft_threads)
  GLUE_FIELD_NULLABLE(int, spec_draft_threads_batch)
  GLUE_FIELD_NULLABLE(arr_str, kv_overrides_keys)
  GLUE_FIELD_NULLABLE(arr_str, kv_overrides_vals)
  GLUE_FIELD_NULLABLE(int, reasoning_budget_tokens)
  GLUE_FIELD_NULLABLE(str, reasoning_budget_message)
  GLUE_FIELD_NULLABLE(str, reasoning_format)
  GLUE_FIELD_NULLABLE(bool, skip_chat_parsing)
  GLUE_FIELD_NULLABLE(bool, prefill_assistant)
};

// Message with prototype ID "load_res".
// Fields register with `handler` in declaration order, and that order is the
// order in which they are written to / read from the wire — do not reorder.
struct glue_msg_load_res
{
  GLUE_HANDLER("load_res")
  GLUE_FIELD(bool, success)
  GLUE_FIELD(int, n_ctx)
  GLUE_FIELD(int, n_batch)
  GLUE_FIELD(int, n_ubatch)
  GLUE_FIELD(int, n_vocab)
  GLUE_FIELD(int, n_ctx_train)
  GLUE_FIELD(int, n_embd)
  GLUE_FIELD(int, n_layer)
  GLUE_FIELD(arr_str, metadata_key)
  GLUE_FIELD(arr_str, metadata_val)
  GLUE_FIELD(int, token_bos)
  GLUE_FIELD(int, token_eos)
  GLUE_FIELD(int, token_eot)
  GLUE_FIELD(arr_int, list_tokens_eog)
  GLUE_FIELD(bool, add_bos_token)
  GLUE_FIELD(bool, add_eos_token)
  GLUE_FIELD(bool, has_encoder)
  GLUE_FIELD(int, token_decoder_start)
  GLUE_FIELD(str, media_marker)
  GLUE_FIELD(bool, has_image_input)
  GLUE_FIELD(bool, has_audio_input)
};

/////////

// Message with prototype ID "cmpl_req": a bool, a string and an array of raw
// byte buffers, serialized in declaration order.
struct glue_msg_completion_req
{
  GLUE_HANDLER("cmpl_req")
  GLUE_FIELD(bool, is_chat)
  GLUE_FIELD(str, data_json)
  GLUE_FIELD(arr_raw, files)
};

// Message with prototype ID "cmpl_res"; carries a single bool field.
struct glue_msg_completion_res
{
  GLUE_HANDLER("cmpl_res")
  GLUE_FIELD(bool, success)
};

/////////

// Message with prototype ID "embd_req": a string and an array of raw byte
// buffers, serialized in declaration order.
struct glue_msg_embedding_req
{
  GLUE_HANDLER("embd_req")
  GLUE_FIELD(str, data_json)
  GLUE_FIELD(arr_raw, files)
};

// Message with prototype ID "embd_res"; carries a single bool field.
struct glue_msg_embedding_res
{
  GLUE_HANDLER("embd_res")
  GLUE_FIELD(bool, success)
};

/////////

// Message with prototype ID "rrnk_req"; carries a single string field.
struct glue_msg_rerank_req
{
  GLUE_HANDLER("rrnk_req")
  GLUE_FIELD(str, data_json)
};

// Message with prototype ID "rrnk_res"; carries a single bool field.
struct glue_msg_rerank_res
{
  GLUE_HANDLER("rrnk_res")
  GLUE_FIELD(bool, success)
};

/////////

// Message with prototype ID "gres_req"; it has no fields, so only the header
// is serialized.
struct glue_msg_get_result_req
{
  GLUE_HANDLER("gres_req")
};

// Message with prototype ID "gres_res": three bools followed by a string,
// serialized in declaration order.
struct glue_msg_get_result_res
{
  GLUE_HANDLER("gres_res")
  GLUE_FIELD(bool, success)
  GLUE_FIELD(bool, has_more)
  GLUE_FIELD(bool, is_error)
  GLUE_FIELD(str, data_json)
};
