
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include "base64.h"

/*
 * Test helper: fails an assertion unless the two byte buffers match.
 *
 * The lengths are asserted equal first. The buffers are then compared
 * byte by byte over data2_length bytes; if any byte differs, every pair
 * is written to stderr as two hex values (differing rows are flagged
 * with " <<<<<<") before assert(0) fires.
 */
void assert_equal(unsigned char *data, int data_length, unsigned char *data2, int data2_length) {
  int idx;
  int mismatches = 0;

  assert(data2_length == data_length);

  /* Pass 1: count differing positions. */
  for (idx = 0; idx < data2_length; idx++) {
    mismatches += (data[idx] != data2[idx]);
  }

  if (mismatches == 0) {
    return;
  }

  /* Pass 2: dump both buffers side by side, marking the rows that differ. */
  for (idx = 0; idx < data2_length; idx++) {
    const char *marker = (data[idx] != data2[idx]) ? " <<<<<<" : "";
    fprintf(stderr, "%.2x vs %.2x%s\n", data[idx], data2[idx], marker);
  }
  assert(0);
}


/*
 * Round-trip check: encodes data with base64_encode, decodes the result
 * with base64_decode, and uses assert_equal to verify that the decoded
 * bytes and length match the input.
 *
 * NOTE(review): neither returned buffer is released here. If
 * base64_encode/base64_decode hand back heap memory, every call leaks
 * two allocations -- confirm ownership against base64.h.
 */
void test_data(unsigned char *data, int data_length) {
  int decoded_length;
  char *encoded = base64_encode(data, data_length);
  unsigned char *decoded = base64_decode(encoded, &decoded_length);

  assert_equal(data, data_length, decoded, decoded_length);
}


/*
 * NOTE(review): none of the checks below release the buffers returned by
 * base64_encode/base64_decode. If those are heap allocations the test
 * leaks them -- confirm ownership against base64.h before adding free().
 */

/* Pins the expected results of the three length/padding helper functions. */
static void check_length_helpers(void) {
  assert(base64_length_for_data_length(0) == 0);
  assert(base64_length_for_data_length(1) == 4);
  assert(base64_length_for_data_length(2) == 4);
  assert(base64_length_for_data_length(3) == 4);
  assert(base64_length_for_data_length(4) == 8);

  assert(num_base64_padding_chars_for_data_length(0) == 0);
  assert(num_base64_padding_chars_for_data_length(1) == 2);
  assert(num_base64_padding_chars_for_data_length(2) == 1);
  assert(num_base64_padding_chars_for_data_length(3) == 0);
  assert(num_base64_padding_chars_for_data_length(4) == 2);
  assert(num_base64_padding_chars_for_data_length(5) == 1);
  assert(num_base64_padding_chars_for_data_length(6) == 0);

  assert(data_length_for_base64("") == 0);
  assert(data_length_for_base64("AA==") == 1);
  assert(data_length_for_base64("AAA=") == 2);
  assert(data_length_for_base64("AAAA") == 3);
  assert(data_length_for_base64("AAAAAA==") == 4);
  assert(data_length_for_base64("AAAAAAA=") == 5);
  assert(data_length_for_base64("AAAAAAAA") == 6);
}

/* Decodes one base64 string and asserts the length reported by base64_decode. */
static void check_decoded_length(char *base64, int expected_length) {
  /* Start from an impossible value so a decoder that never writes the
     out-parameter fails the assertion instead of reading garbage. */
  int decoded_length = -1;
  unsigned char *decoded = base64_decode(base64, &decoded_length);

  (void)decoded; /* only the reported length is checked here */
  assert(decoded_length == expected_length);
}

/* Checks the decoded length reported for each padding variant. */
static void check_decoded_lengths(void) {
  check_decoded_length("", 0);
  check_decoded_length("AA==", 1);
  check_decoded_length("AAA=", 2);
  check_decoded_length("AAAA", 3);
  check_decoded_length("AAAAAA==", 4);
}

/* Checks the text produced by encoding 0 to 4 zero bytes. */
static void check_encoding_of_zeros(void) {
  unsigned char zeros[4] = {0, 0, 0, 0};

  assert(0 == strcmp(base64_encode(zeros, 0), ""));
  assert(0 == strcmp(base64_encode(zeros, 1), "AA=="));
  assert(0 == strcmp(base64_encode(zeros, 2), "AAA="));
  assert(0 == strcmp(base64_encode(zeros, 3), "AAAA"));
  assert(0 == strcmp(base64_encode(zeros, 4), "AAAAAA=="));
}

/*
 * Round-trips 3-byte inputs through test_data. For every pair (i, j) in
 * 0..255 the third value is (i + j) % 256, and the triple is tested in
 * two rotations: (i, j, sum) and (sum, i, j).
 */
static void check_three_byte_round_trips(void) {
  unsigned char triple[3];

  for (int i = 0; i < 256; i++) {
    for (int j = 0; j < 256; j++) {
      unsigned char sum = (unsigned char)((i + j) % 256);

      triple[0] = (unsigned char)i;
      triple[1] = (unsigned char)j;
      triple[2] = sum;
      test_data(triple, 3);

      triple[0] = sum;
      triple[1] = (unsigned char)i;
      triple[2] = (unsigned char)j;
      test_data(triple, 3);
    }
  }
}

/* Decodes a fixed 50-byte JSON message and compares it with the expected bytes. */
static void check_known_vector(void) {
  /* Writable array (not a const literal) because assert_equal takes
     non-const pointers; the trailing NUL is excluded from the length. */
  unsigned char expected[] =
      "{\"to\":\"js\",\"id\":1,\"result\":{},\"from\":\"subprocess\"}";
  int expected_length = (int)(sizeof expected - 1);
  int decoded_length = -1;
  unsigned char *decoded = base64_decode(
      "eyJ0byI6ImpzIiwiaWQiOjEsInJlc3VsdCI6e30sImZyb20iOiJzdWJwcm9jZXNzIn0=",
      &decoded_length);

  assert(expected_length == 50);
  assert_equal(expected, expected_length, decoded, decoded_length);
}

/*
 * Test driver: runs every check group in order. A failed check stops the
 * run through assert(); if all pass, "OK" is written to stderr and the
 * program returns 0.
 */
int main(void) {
  check_length_helpers();
  check_decoded_lengths();
  check_encoding_of_zeros();
  check_three_byte_round_trips();
  check_known_vector();

  fprintf(stderr, "OK\n");

  return 0;
}
