
#include <stdlib.h>
#include <string.h>
#include "base64.h"

// Base64:
//    2**8 = 256
//    2**6 = 64
//    8 * 3 = 6 * 4 = 24
//    (num padding chars) \in {0,1,2}

// This implementation does not try to be clever or maximally efficient.
// It attempts to be readable and correct.

// The 64 output symbols of the encoding, ordered by their 6-bit value:
// 0-25 upper case, 26-51 lower case, 52-61 digits, 62 '+', 63 '/'.
char *ALPHABET64 =
  "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
  "abcdefghijklmnopqrstuvwxyz"
  "0123456789"
  "+/";


// Encodes data_length bytes starting at data as a base64 string.
//
// Returns a freshly malloc'ed, NUL-terminated string that the caller must
// free(). Returns NULL if the computed output length is negative (which can
// only happen for a negative data_length) or if the allocation fails.
char* base64_encode(const unsigned char *data, int data_length) {

  // result string with null terminator
  int base64len = base64_length_for_data_length(data_length);
  if (base64len < 0) {
    // A negative length would turn into a bogus allocation size and an
    // out-of-bounds write of the terminator below.
    return NULL;
  }

  // Do the +1 in size_t so it cannot overflow int.
  char *base64 = malloc((size_t)base64len + 1);
  if (base64 == NULL) {
    return NULL;
  }
  base64[base64len] = '\0';

  // blocks: only the last block may carry padding characters
  int numPaddingChars = num_base64_padding_chars_for_data_length(data_length);
  int numBlocks = base64len / 4;
  for (int i = 0; i < numBlocks; i++) {
    int isLastBlock = (i == (numBlocks - 1));
    base64encode_block(data, base64, i, isLastBlock ? numPaddingChars : 0);
  }

  return base64;
}

// Decodes the NUL-terminated base64 string into a freshly malloc'ed buffer.
//
// The number of decoded bytes is stored in *result_length and the buffer is
// returned; the caller must free() it. Only complete 4-character blocks are
// processed; characters are not validated against the alphabet.
// If the allocation fails, *result_length is set to 0 and NULL is returned.
// (For an input with no complete block the length is 0 and the returned
// pointer is whatever malloc(0) yields, which may itself be NULL.)
unsigned char* base64_decode(const char *base64, int *result_length) {

  // build table: 0xFF marks bytes that are not part of the alphabet
  int i;
  unsigned char ALPHABET64_TO_INDEX[256];
  memset(ALPHABET64_TO_INDEX, 0xFF, sizeof ALPHABET64_TO_INDEX);
  for (i = 0; i < 64; i++) {
    // Index through unsigned char so the subscript can never be negative.
    ALPHABET64_TO_INDEX[(unsigned char)ALPHABET64[i]] = (unsigned char)i;
  }

  // allocate data
  int data_length = data_length_for_base64(base64);
  unsigned char *data = malloc(data_length);
  if (data == NULL) {
    // Either allocation failed or data_length was 0; in both cases there is
    // nothing that can be written to.
    *result_length = 0;
    return NULL;
  }
  *result_length = data_length;

  // process blocks
  int numBlocks = (int)strlen(base64) / 4;
  for (i = 0; i < numBlocks; i++) {
    base64decode_block(base64, data, i, ALPHABET64_TO_INDEX);
  }

  return data;
}

// Number of base64 characters (terminator not included) produced for n input
// bytes: every started group of 3 bytes becomes one block of 4 characters.
//
//   n:       0  1  2  3  4  5  6
//   blocks:  0  1  1  1  2  2  2
int base64_length_for_data_length(int n) {
  // Adding 2 before the truncating division rounds n / 3 upwards.
  int blocks = (n + 2) / 3;
  return 4 * blocks;
}

// Number of '=' characters that end the encoding of n input bytes.
//
//   n % 3:    0  1  2
//   padding:  0  2  1
int num_base64_padding_chars_for_data_length(int n) {
  int leftover = n % 3;        // bytes in the final, incomplete group
  int missing = 3 - leftover;  // 3 when the final group is complete
  return missing % 3;          // fold that 3 back to 0
}

// Number of bytes that decoding the NUL-terminated string base64 produces.
//
// Only complete 4-character blocks count; any 1-3 trailing characters after
// the last complete block are ignored. Each '=' in the last two positions of
// the final complete block removes one byte from the total.
// Returns 0 when the string holds no complete block.
int data_length_for_base64(const char *base64) {
  int len = (int)strlen(base64);
  int numBlocks = len / 4;
  if (numBlocks == 0) {
    return 0;
  }

  // Look at the end of the last complete block rather than the end of the
  // string, so the result matches the blocks that are actually decoded even
  // when len is not a multiple of 4.
  int end = numBlocks * 4;
  int numPaddingChars = 0;
  if (base64[end - 1] == '=') {numPaddingChars++;}
  if (base64[end - 2] == '=') {numPaddingChars++;}
  return (3 * numBlocks) - numPaddingChars;
}

// Encodes input block number blockNum (3 bytes of src) into output block
// number blockNum (4 characters of dest).
//
// numPadingChars says how many trailing input bytes are absent from this
// block: with 1 the third byte is not read, with 2 the second and third are
// not read; absent bytes count as zero and their output positions are
// overwritten with '='.
void base64encode_block(const unsigned char *src, char *dest, int blockNum, int numPadingChars) {
  const unsigned char *in = src + blockNum * 3;
  char *out = dest + blockNum * 4;

  // Pack the available bytes into one 24-bit group:
  //   bytes:   11111111 22222222 33333333
  //   sextets: 111111 112222 222233 333333
  unsigned long group = (unsigned long)in[0] << 16;
  if (numPadingChars < 2) {
    group |= (unsigned long)in[1] << 8;
  }
  if (numPadingChars < 1) {
    group |= (unsigned long)in[2];
  }

  // 0x3F = 0b00111111 selects one sextet
  out[0] = ALPHABET64[(group >> 18) & 0x3F];
  out[1] = ALPHABET64[(group >> 12) & 0x3F];
  out[2] = ALPHABET64[(group >> 6) & 0x3F];
  out[3] = ALPHABET64[group & 0x3F];

  // Replace the characters that stand for absent bytes with padding.
  if (numPadingChars == 2) {
    out[2] = '=';
  }
  if (numPadingChars >= 1) {
    out[3] = '=';
  }
}

// Decodes input block number blockNum (4 characters of src) into output block
// number blockNum (up to 3 bytes of dest).
//
// ALPHABET64_TO_INDEX must have 256 entries and map each byte value to its
// 6-bit value. Characters are not validated: whatever the table holds for a
// byte is used as is.
// A '=' in the fourth position suppresses the third output byte; a '=' in
// both the third and fourth positions suppresses the second one as well.
void base64decode_block(const char *src, unsigned char *dest, int blockNum, unsigned char *ALPHABET64_TO_INDEX) {
  int src_pos = blockNum * 4;
  int dest_pos = blockNum * 3;

  // Read the characters as unsigned char: plain char may be signed, in which
  // case a byte >= 0x80 would become a negative table index and read out of
  // bounds.
  unsigned char c1 = (unsigned char)src[src_pos];
  unsigned char c2 = (unsigned char)src[src_pos + 1];
  unsigned char c3 = (unsigned char)src[src_pos + 2];
  unsigned char c4 = (unsigned char)src[src_pos + 3];
  int i1 = ALPHABET64_TO_INDEX[c1];
  int i2 = ALPHABET64_TO_INDEX[c2];
  int i3 = ALPHABET64_TO_INDEX[c3];
  int i4 = ALPHABET64_TO_INDEX[c4];

  // c: 11111122 22223333 33444444
  // o: 11111111 22222222 33333333

  dest[dest_pos] = (unsigned char)((i1 << 2) | (i2 >> 4));

  // n = number of padding characters counted in this block
  int n = 0;
  if (c4 == '=') {n++;}
  if (c3 == '=') {n++;}
  if (n <= 1) {dest[dest_pos + 1] = (unsigned char)(((i2 << 4) & 0xFF) | (i3 >> 2));}
  if (n == 0) {dest[dest_pos + 2] = (unsigned char)(((i3 << 6) & 0xFF) | (i4));}
}


