#pragma once
#include "./std_headers.h"
#include <atomic>

/*
  The packetbuf system is a convenient and high-performance way of
  sending typed data around between processes.

  It's *not* architecture independent, so you'd be hosed if you tried
  to communicate between a big-endian and a little-endian machine.

  In the robot environment, I standardize on the Intel (little-endian) binary
  format, so the AVR32 code does endian gymnastics to make that work.

  Reading:

  A packet is a binary blob of data. Walk through it as you read
  data objects one at a time.

  Example: {
    packet rd(1024);
    rd.add_read(fd, 1024);

    int foo;
    rd.get(foo);
    float bar;
    rd.get(bar);
    string buz;
    rd.get(buz);
  }

  Writing:

  The interesting work is done in overloaded packet_wr_value functions.
  These should correspond to packet_rd_value functions which extract
  the data item back out.

  Example: {
    packet wr;
    wr.add(17);
    wr.add(3.0);
    wr.add("foo"s);
    write(fd, wr.ptr(), wr.size());
  }

  Supporting your own data types:

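  You have to implement packet_wr_value, packet_rd_value and
  packet_get_typetag for your type, in namespace packetio.
  DECL_PACKETIO(YourType) declares all three (plus packet_rd_value_compat).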

  SECURITY: there's some attempt at input validation, but it might have bugs.
  If you need to lock it down for external input, carefully go through all
  the size calculations on 32 and 64-bit architectures.

*/

struct packet_contents;
struct packet_annotations;
struct packet;

/*
  This is the actual data in the packet.

  refcnt: presumably the number of packets sharing this block.
          packet::reserve() takes the slow path (reserve_full) whenever
          refcnt > 1, so a shared block is not written in place.
  alloc:  the capacity that packet::reserve() compares requested sizes
          against; a larger request also takes the slow path.
  buf:    declared with a single element but used as the start of the data.
          NOTE(review): presumably alloc_contents() over-allocates so that buf
          can hold `alloc` bytes (the classic trailing-array idiom) — confirm
          before changing this member layout in any way.
*/
struct packet_contents {
  std::atomic<int> refcnt; // sharing count; see above
  size_t alloc; // capacity checked by packet::reserve(); see above
  uint8_t buf[1]; // first byte of the variable-length data area
};

/*
  Key/value annotations attached to a packet. Instances are reference-counted
  through packet::incref / packet::decref (defined out of line); a freshly
  constructed instance has a zero count and an empty table.
*/
struct packet_annotations {
  packet_annotations() = default;

  std::atomic<int> refcnt{0};                 // owner count, starts at zero
  std::map<std::string, std::string> table;   // annotation key -> value
};

// ----------------------------------------------------------------------

// Error type for write overruns. The constructor (which takes the overrun
// amount) and the destructor are defined out of line.
struct packet_wr_overrun_err : public std::overflow_error {
  explicit packet_wr_overrun_err(int amount);
  ~packet_wr_overrun_err() override;

  int howmuch; // amount passed to the constructor
};

// Error type for read overruns; packet::get_bytes throws it when it cannot
// satisfy a read. Constructor and destructor are defined out of line.
struct packet_rd_overrun_err : public std::overflow_error {
  explicit packet_rd_overrun_err(int amount);
  ~packet_rd_overrun_err() override;

  int howmuch; // amount passed to the constructor
};

// Error type for a type-tag mismatch while reading: records the tag that was
// expected and the tag actually found. Constructor and destructor are defined
// out of line.
struct packet_rd_type_err : public std::invalid_argument {
  explicit packet_rd_type_err(std::string const &expected_tag, std::string const &got_tag);
  ~packet_rd_type_err() override;

  std::string expected; // tag the reader asked for
  std::string got;      // tag present in the packet
};

/*
  Counters describing packet buffer activity; the live instance is the static
  packet::stats (see packet::stats_str / packet::clear_stats).

  Every counter has a zero default member initializer. A packet_stats declared
  without an initializer (a local snapshot, a reset value) therefore starts at
  zero instead of holding indeterminate values. The struct is still an
  aggregate (C++14+), so brace-initialization with explicit values keeps
  working.
*/
struct packet_stats {
  int incref_count{0};
  int decref_count{0};
  int alloc_count{0};
  int free_count{0};
  int cow_count{0};
  int expand_count{0};
  long long copy_bytes_count{0};
};


/*
  DECL_PACKETIO(T) declares, inside namespace packetio, the serialization
  hooks for type T (declarations only; definitions live elsewhere):
    packet_get_typetag(x)           returns the type tag string for T
    packet_wr_value(p, x)           writes x into packet p
    packet_rd_value(p, x)           reads x back out of packet p
    packet_rd_value_compat(x, tag)  returns a std::function reader, presumably
                                    one able to decode data written under `tag`
  Every outside name is fully qualified (::std, ::packet) so the expansion
  means the same thing wherever the macro is used.
*/
#define DECL_PACKETIO(T) \
  namespace packetio { \
    ::std::string packet_get_typetag(T const &x); \
    void packet_wr_value(::packet &p, T const &x); \
    void packet_rd_value(::packet &p, T &x); \
    ::std::function<void(::packet &, T &)> packet_rd_value_compat(T const &x, ::std::string const &typetag); \
  }

// Serialization hooks for the built-in types. The S*/U* names are the
// fixed-width integer aliases (presumably provided by std_headers.h).

// Boolean and character
DECL_PACKETIO(bool);
DECL_PACKETIO(char);

// Signed integers
DECL_PACKETIO(S8);
DECL_PACKETIO(S16);
DECL_PACKETIO(S32);
DECL_PACKETIO(S64);

// Unsigned integers
DECL_PACKETIO(U8);
DECL_PACKETIO(U16);
DECL_PACKETIO(U32);
DECL_PACKETIO(U64);

// Floating point and complex
DECL_PACKETIO(float);
DECL_PACKETIO(double);
DECL_PACKETIO(std::complex<double>);

// Time value and strings
DECL_PACKETIO(timeval);
DECL_PACKETIO(string);
DECL_PACKETIO(char const *);

namespace packetio {

/*
  Any vector is handled by writing a size followed by the items. Watch
  out for heap overflows. stl_vector seems to protect against this by
  computing maximum size (# elements) as (size_t)-1 / sizeof(ITEM),
  but we check again here by comparing against p.remaining() / sizeof(ITEM).

  We use uint32_t rather than size_t for compatibility
*/

template <typename T>
void packet_wr_value(packet &p, vector<T> const &x);

template <typename T>
string packet_get_typetag(vector<T> const &x);

template <typename T>
void packet_rd_value(packet &p, vector<T> &x);


template <>
void packet_wr_value(packet &p, vector<U8> const &x);

template <>
void packet_rd_value(packet &p, vector<U8> &x);


// ----------------------------------------------------------------------

template <typename T1, typename T2>
string packet_get_typetag(pair<T1, T2> const &x);

template <typename T1, typename T2>
void packet_wr_value(packet &p, pair<T1, T2> const &x);

template <typename T1, typename T2>
void packet_rd_value(packet &p, pair<T1, T2> &x);

template <typename T1, typename T2>
string packet_get_typetag(map<T1, T2> const &x);

template <typename T1, typename T2>
void packet_wr_value(packet &p, map<T1, T2> const &x);

template <typename T1, typename T2>
void packet_rd_value(packet &p, map<T1, T2> &x);
} // namespace packetio

// ----------------------------------------------------------------------

struct gzFile_s;

/*
  A packet is a byte buffer with a read cursor (rd_pos) and a write cursor
  (wr_pos). The bytes live in a reference-counted packet_contents block and
  the optional key/value annotations in a reference-counted packet_annotations
  block (the counts suggest copies share storage). reserve() obtains a private,
  large-enough buffer before the packet is written.
*/
struct packet {
  packet();
  explicit packet(size_t size);
  explicit packet(u_char const *data, size_t size);
  explicit packet(string const &data);
  packet(const packet &other);

  // Move construction takes over `other`'s contents and annotations without
  // touching their reference counts. It then leaves `other` empty: no
  // buffers, and both cursors reset to 0. Without the reset, the moved-from
  // packet would keep positions into storage it no longer holds.
  packet(packet &&other) noexcept
      : contents(other.contents), annotations(other.annotations), rd_pos(other.rd_pos), wr_pos(other.wr_pos)
  {
    other.contents = nullptr;
    other.annotations = nullptr;
    other.rd_pos = 0;
    other.wr_pos = 0;
  }

  packet &operator=(packet const &other);
  packet &operator=(packet &&other) noexcept;
  ~packet();

  string as_string();

  int rd_to_file(int fd) const;
  int rd_to_file(FILE *fp) const;
  int to_file(int fd) const;
  int to_file(FILE *fp) const;
  string dump_hex() const;
  void dump(FILE *fp = stderr) const;

  void to_file_boxed(int fd) const;
  void to_gzfile_boxed(gzFile_s *fp) const;
  static packet from_file_boxed(int fd);
  static packet from_gzfile_boxed(gzFile_s *fp);

  size_t size() const;
  size_t size_bits() const;
  float size_kbits() const;
  size_t alloc() const;
  void resize(size_t newsize);
  void make_mutable();
  void clear();

  const uint8_t *wr_ptr() const; // pointer to the write position at the end of the packet
  const uint8_t *rd_ptr() const; // pointer to the read position
  const uint8_t *ptr() const;    // pointer to the beginning
  const uint8_t *begin() const;  // same as ptr()
  const uint8_t *end() const;    // pointer to the end, same as wr_ptr()
  uint8_t operator[](int index) const;

  uint8_t *wr_ptr();
  uint8_t *rd_ptr();
  uint8_t *ptr();
  uint8_t *begin();
  uint8_t *end();
  uint8_t &operator[](int index);

  string &annotation(string const &key);
  string annotation(string const &key) const;
  bool has_annotation(string const &key) const;

  // writing

  // Embedded packet, with length prefix
  void add_pkt(packet const &wr);

  // Raw bytes. Receiver must know length to recover
  void add_bytes(u_char const *data, size_t size); // preferred
  void add_bytes(char const *data, size_t size);
  void add_reversed(u_char const *data, size_t size);

  // String terminated by a newline. Use get_nl_string on the receiver
  void add_nl_string(char const *s);
  void add_nl_string(string const &s);

  // String prefixed with an 8-bit length. Use get_len8_string on the receiver.
  void add_len8_string(string const &s);

  // String with no prefix terminator. The corresponding get_remainder_string gets the remainder of the packet.
  void add_remainder_string(string const &s);

  // Add a tag, which can be read with test_typetag (returning false if no match) or check_typetag (throwing if no
  // match)
  void add_literal_typetag(char const *tag);

  template <typename T>
  void add_checked(const T &x);
  template <typename T>
  void add(const T &x);
  template <typename T>
  void add_typetag(const T &x);

  // Big-endian numbers.
  void add_be_uint64(uint64_t x);
  void add_be_uint32(uint32_t x);
  void add_be_uint24(uint32_t x);
  void add_be_uint16(uint32_t x);
  void add_be_uint8(uint32_t x);
  void add_be_double(double x);
  void add_be_float(float x);

  // Little-endian numbers
  void add_le_uint64(uint64_t x);
  void add_le_uint32(uint32_t x);
  void add_le_uint24(uint32_t x);
  void add_le_uint16(uint32_t x);
  void add_le_uint8(uint32_t x);
  void add_le_double(double x);
  void add_le_float(float x);

  // Contents of a file, no prefix
  ssize_t add_read(int fd, size_t readsize);
  ssize_t add_pread(int fd, size_t readsize, off_t offset);
  ssize_t add_read(FILE *fp, size_t readsize);
  void add_file_contents(int fd);
  void add_file_contents(FILE *fp);

  // reading
  void rewind();
  ssize_t remaining() const;
  packet get_remainder();

  void get_skip(int n);

  bool get_test(uint8_t *data, size_t size);  // returns false if it fails
  void get_bytes(uint8_t *data, size_t size); // throws packet_rd_overrun_err if it fails
  void get_bytes(char *data, size_t size);    // throws packet_rd_overrun_err if it fails

  void get_reversed(uint8_t *data, size_t size);

  template <typename T>
  void get(T &x);
  template <typename T>
  void get_checked(T &x);
  template <typename T>
  void get_compat(T &x);
  template <typename T>
  std::function<void(packet &, T &)> get_compat_func(T const &x, string const &typetag);
  template <typename T>
  T fget();
  template <typename T>
  T fget_checked();
  template <typename T>
  T fget_compat();

  packet get_pkt();
  bool get_bool();
  char get_char();

  uint64_t get_be_uint64();
  uint32_t get_be_uint32();
  uint32_t get_be_uint24();
  uint16_t get_be_uint16();
  uint8_t get_be_uint8();
  int64_t get_be_int64();
  int32_t get_be_int32();
  int32_t get_be_int24();
  int16_t get_be_int16();
  int8_t get_be_int8();
  double get_be_double();
  float get_be_float();

  uint64_t get_le_uint64();
  uint32_t get_le_uint32();
  uint32_t get_le_uint24();
  uint16_t get_le_uint16();
  uint8_t get_le_uint8();
  int64_t get_le_int64();
  int32_t get_le_int32();
  int32_t get_le_int24();
  int16_t get_le_int16();
  int8_t get_le_int8();
  double get_le_double();
  float get_le_float();

  string get_nl_string();
  string get_len8_string();
  string get_lenbe16_string();
  string get_lenbe32_string();
  string get_string(size_t size);
  string get_remainder_string();

  void check_at_end();

  void write_to_file(char const *fn) const;

  static packet read_from_file(char const *fn);
  static packet read_from_fd(int fd);

  // stats
  static string stats_str();
  static void clear_stats();

  // internals
  static packet_contents *alloc_contents(size_t alloc);
  static void decref(packet_contents *&it);
  static void incref(packet_contents *it);

  // Fast path for growing or unsharing the buffer before a write. It falls
  // through to reserve_full() when there is no contents block at all (e.g. a
  // moved-from packet), when the request exceeds the current capacity, or
  // when the block is shared (refcnt > 1). The null check keeps this inline
  // path from dereferencing a null `contents`.
  // NOTE(review): assumes reserve_full() can allocate fresh contents when
  // `contents` is null — confirm against its out-of-line definition.
  inline void reserve(size_t new_size) {
    if (!contents || new_size > contents->alloc || contents->refcnt > 1) {
      reserve_full(new_size);
    }
  }
  void reserve_full(size_t new_size);
  static void decref(packet_annotations *&it);
  static void incref(packet_annotations *it);

  // tests
  static string run_test(int testid);

  // ------------
  packet_contents *contents{nullptr};
  packet_annotations *annotations{nullptr};
  size_t rd_pos{0};
  size_t wr_pos{0};

  static packet_stats stats;
};

bool operator==(packet const &a, packet const &b);

// ----------------------------------------------------------------------

using packet_queue = deque<packet>;

ostream &operator<<(ostream &s, packet const &it);

// The packetio serializer templates for vector, pair and map are declared once,
// in the packetio namespace block ahead of struct packet. Repeating identical
// function-template declarations here would add nothing, so they are not
// redeclared.
