
import "fmt"

// parseOpts sorts the optional sign/verify arguments by their dynamic type.
//
// Recognized argument types:
//   []byte   - augmentation bytes used for signing (default: nil)
//   [][]byte - augmentation list used for aggregate verify (default: nil)
//   bool     - hash (true, the default) or encode (false)
//
// If a type appears more than once, the last occurrence wins. The results
// are, in order: the []byte value, the [][]byte value, the hash/encode
// flag and a success flag. The success flag is false as soon as an
// argument of any other type is met; both slices are then returned as
// nil and the hash/encode flag holds whatever value had been reached.
func parseOpts(optional ...interface{}) ([]byte, [][]byte, bool, bool) {
    var (
        signAug   []byte
        verifyAug [][]byte
    )
    hashMode := true

    for i := range optional {
        if b, ok := optional[i].([]byte); ok {
            signAug = b
        } else if bb, ok := optional[i].([][]byte); ok {
            verifyAug = bb
        } else if flag, ok := optional[i].(bool); ok {
            hashMode = flag
        } else {
            return nil, nil, hashMode, false
        }
    }
    return signAug, verifyAug, hashMode, true
}

// bytesAllZero reports whether every byte of s is zero. An empty or nil
// slice yields true. Scanning stops at the first non-zero byte.
func bytesAllZero(s []byte) bool {
    idx := 0
    for idx < len(s) && s[idx] == 0 {
        idx++
    }
    return idx == len(s)
}

//
// Serialization/Deserialization.
//

// Scalar serdes

// Serialize returns a newly allocated slice of BLST_SCALAR_BYTES bytes
// filled in by blst_bendian_from_scalar.
func (s *Scalar) Serialize() []byte {
    buf := make([]byte, BLST_SCALAR_BYTES)
    C.blst_bendian_from_scalar((*C.byte)(&buf[0]), s)
    return buf
}

// Deserialize loads the receiver from in using blst_scalar_from_bendian
// and returns the receiver when blst_sk_check accepts the loaded value.
//
// It returns nil when len(in) differs from BLST_SCALAR_BYTES (the
// receiver is then untouched) or when blst_sk_check rejects the value
// (the receiver has already been overwritten at that point).
func (s *Scalar) Deserialize(in []byte) *Scalar {
    if len(in) == BLST_SCALAR_BYTES {
        C.blst_scalar_from_bendian(s, (*C.byte)(&in[0]))
        if C.blst_sk_check(s) {
            return s
        }
    }
    return nil
}

// Valid reports the result of blst_sk_check for the receiver.
func (s *Scalar) Valid() bool {
    ok := C.blst_sk_check(s)
    return bool(ok)
}

// HashTo overwrites the receiver with HashToScalar(msg, dst) and
// returns true. When HashToScalar yields nil the receiver is left
// unchanged and false is returned.
func (s *Scalar) HashTo(msg []byte, dst []byte) bool {
    hashed := HashToScalar(msg, dst)
    if hashed == nil {
        return false
    }
    *s = *hashed
    return true
}

// HashToScalar expands msg under dst into 48 bytes with
// blst_expand_message_xmd and converts those bytes into a Scalar with
// blst_scalar_from_be_bytes.
//
// An empty msg or dst is passed to C as a nil pointer with length 0.
// The result is nil when blst_scalar_from_be_bytes reports failure.
func HashToScalar(msg []byte, dst []byte) *Scalar {
    var (
        expanded       [48]C.byte
        msgPtr, dstPtr *C.byte
    )

    // Only take the address of the first element when there is one.
    if len(msg) != 0 {
        msgPtr = (*C.byte)(&msg[0])
    }
    if len(dst) != 0 {
        dstPtr = (*C.byte)(&dst[0])
    }

    C.blst_expand_message_xmd(&expanded[0], C.size_t(len(expanded)),
                              msgPtr, C.size_t(len(msg)),
                              dstPtr, C.size_t(len(dst)))

    out := new(Scalar)
    if !C.blst_scalar_from_be_bytes(out, &expanded[0], C.size_t(len(expanded))) {
        return nil
    }
    return out
}

//
// LEndian
//

// ToLEndian returns a newly allocated slice of BLST_SCALAR_BYTES bytes
// filled in by blst_lendian_from_scalar.
func (fr *Scalar) ToLEndian() []byte {
    out := make([]byte, BLST_SCALAR_BYTES)
    C.blst_lendian_from_scalar((*C.byte)(&out[0]), fr)
    return out
}

// ToLEndian returns a newly allocated slice of BLST_FP_BYTES bytes
// filled in by blst_lendian_from_fp.
func (fp *Fp) ToLEndian() []byte {
    out := make([]byte, BLST_FP_BYTES)
    C.blst_lendian_from_fp((*C.byte)(&out[0]), fp)
    return out
}

// FromLEndian loads the receiver from arr and returns the receiver.
//
// Input of exactly BLST_SCALAR_BYTES bytes goes through
// blst_scalar_from_lendian; longer input goes through
// blst_scalar_from_le_bytes, whose return value is not inspected here.
// Shorter input leaves the receiver untouched and yields nil.
func (fr *Scalar) FromLEndian(arr []byte) *Scalar {
    switch n := len(arr); {
    case n < BLST_SCALAR_BYTES:
        return nil
    case n == BLST_SCALAR_BYTES:
        C.blst_scalar_from_lendian(fr, (*C.byte)(&arr[0]))
    default:
        C.blst_scalar_from_le_bytes(fr, (*C.byte)(&arr[0]), C.size_t(n))
    }
    return fr
}

// FromLEndian loads the receiver from arr with blst_fp_from_lendian and
// returns the receiver. It returns nil, leaving the receiver untouched,
// unless arr is exactly BLST_FP_BYTES long.
func (fp *Fp) FromLEndian(arr []byte) *Fp {
    if len(arr) == BLST_FP_BYTES {
        C.blst_fp_from_lendian(fp, (*C.byte)(&arr[0]))
        return fp
    }
    return nil
}

//
// BEndian
//

// ToBEndian returns a newly allocated slice of BLST_SCALAR_BYTES bytes
// filled in by blst_bendian_from_scalar.
func (fr *Scalar) ToBEndian() []byte {
    out := make([]byte, BLST_SCALAR_BYTES)
    C.blst_bendian_from_scalar((*C.byte)(&out[0]), fr)
    return out
}

// ToBEndian returns a newly allocated slice of BLST_FP_BYTES bytes
// filled in by blst_bendian_from_fp.
func (fp *Fp) ToBEndian() []byte {
    out := make([]byte, BLST_FP_BYTES)
    C.blst_bendian_from_fp((*C.byte)(&out[0]), fp)
    return out
}

// FromBEndian loads the receiver from arr and returns the receiver.
//
// Input of exactly BLST_SCALAR_BYTES bytes goes through
// blst_scalar_from_bendian; longer input goes through
// blst_scalar_from_be_bytes, whose return value is not inspected here.
// Shorter input leaves the receiver untouched and yields nil.
func (fr *Scalar) FromBEndian(arr []byte) *Scalar {
    switch n := len(arr); {
    case n < BLST_SCALAR_BYTES:
        return nil
    case n == BLST_SCALAR_BYTES:
        C.blst_scalar_from_bendian(fr, (*C.byte)(&arr[0]))
    default:
        C.blst_scalar_from_be_bytes(fr, (*C.byte)(&arr[0]), C.size_t(n))
    }
    return fr
}

// FromBEndian loads the receiver from arr with blst_fp_from_bendian and
// returns the receiver. It returns nil, leaving the receiver untouched,
// unless arr is exactly BLST_FP_BYTES long.
func (fp *Fp) FromBEndian(arr []byte) *Fp {
    if len(arr) == BLST_FP_BYTES {
        C.blst_fp_from_bendian(fp, (*C.byte)(&arr[0]))
        return fp
    }
    return nil
}

//
// Printing
//

// PrintBytes writes one line of the form "name = <hex of val>" to
// standard output.
func PrintBytes(val []byte, name string) {
    const layout = "%s = %02x\n"
    fmt.Printf(layout, name, val)
}

// Print writes the scalar's ToBEndian bytes to standard output through
// PrintBytes, labelled with name.
func (s *Scalar) Print(name string) {
    PrintBytes(s.ToBEndian(), name)
}

// Print writes a "name:" header line followed by the ToBEndian bytes of
// the x and y fields, labelled "  x" and "  y".
func (p *P1Affine) Print(name string) {
    fmt.Printf("%s:\n", name)
    PrintBytes(p.x.ToBEndian(), "  x")
    PrintBytes(p.y.ToBEndian(), "  y")
}

// Print converts p with ToAffine and prints the result under name.
//
// The "name:" header is emitted by the affine value's own Print, so it
// is not written here as well; doing both put the header line on the
// output twice.
func (p *P1) Print(name string) {
    aff := p.ToAffine()
    aff.Print(name)
}

// Print writes a "name:" header line followed by the ToBEndian bytes of
// fp[0] and fp[1], labelled "    0" and "    1".
func (f *Fp2) Print(name string) {
    fmt.Printf("%s:\n", name)
    for i, label := range [...]string{"    0", "    1"} {
        PrintBytes(f.fp[i].ToBEndian(), label)
    }
}

// Print writes a "name:" header line and then delegates to the Print
// methods of the x and y fields, labelled "  x" and "  y".
func (p *P2Affine) Print(name string) {
    x, y := &p.x, &p.y

    fmt.Printf("%s:\n", name)
    x.Print("  x")
    y.Print("  y")
}

// Print converts p with ToAffine and prints the result under name.
//
// The "name:" header is emitted by the affine value's own Print, so it
// is not written here as well; doing both put the header line on the
// output twice.
func (p *P2) Print(name string) {
    aff := p.ToAffine()
    aff.Print(name)
}

//
// Equality
//

// Equals reports whether the two Scalar values compare equal under
// Go's == on the dereferenced structs.
func (s1 *Scalar) Equals(s2 *Scalar) bool {
    same := *s1 == *s2
    return same
}

// Equals reports whether the two Fp values compare equal under Go's ==
// on the dereferenced structs.
func (e1 *Fp) Equals(e2 *Fp) bool {
    same := *e1 == *e2
    return same
}

// Equals reports whether the two Fp2 values compare equal under Go's ==
// on the dereferenced structs.
func (e1 *Fp2) Equals(e2 *Fp2) bool {
    same := *e1 == *e2
    return same
}

// Equals reports the result of blst_p1_affine_is_equal for e1 and e2.
func (e1 *P1Affine) Equals(e2 *P1Affine) bool {
    eq := C.blst_p1_affine_is_equal(e1, e2)
    return bool(eq)
}

// Equals reports the result of blst_p1_is_equal for e1 and e2.
func (e1 *P1) Equals(e2 *P1) bool {
    eq := C.blst_p1_is_equal(e1, e2)
    return bool(eq)
}

// Equals reports the result of blst_p2_affine_is_equal for e1 and e2.
func (e1 *P2Affine) Equals(e2 *P2Affine) bool {
    eq := C.blst_p2_affine_is_equal(e1, e2)
    return bool(eq)
}

// Equals reports the result of blst_p2_is_equal for e1 and e2.
func (e1 *P2) Equals(e2 *P2) bool {
    eq := C.blst_p2_is_equal(e1, e2)
    return bool(eq)
}

// private thunk for testing

// expandMessageXmd returns len_in_bytes bytes produced by
// blst_expand_message_xmd from msg and dst.
//
// An empty msg or dst is passed to C as a nil pointer with length 0.
// A zero len_in_bytes yields an empty slice without calling into C:
// the output buffer then has no first element whose address could be
// passed, and taking &ret[0] would panic with an index-out-of-range
// error.
func expandMessageXmd(msg []byte, dst []byte, len_in_bytes int) []byte {
    ret := make([]byte, len_in_bytes)
    if len(ret) == 0 {
        return ret
    }

    var msgC *C.byte
    if len(msg) > 0 {
        msgC = (*C.byte)(&msg[0])
    }

    var dstC *C.byte
    if len(dst) > 0 {
        dstC = (*C.byte)(&dst[0])
    }

    C.blst_expand_message_xmd((*C.byte)(&ret[0]), C.size_t(len(ret)),
                              msgC, C.size_t(len(msg)),
                              dstC, C.size_t(len(dst)))
    return ret
}

// breakdown derives three integers (nx, ny, wnd) from nbits, window and
// ncpus and returns them in that order.
//
// When nbits > window*ncpus, nx is 1 and the working window is
// window - bits.Len(ncpus/4). Otherwise nx starts at 2 with a working
// window of window-2 and grows while (nbits/wnd+1)*nx < ncpus, the
// working window being recomputed as window - bits.Len(3*nx/2) after
// each step; the last increment is then undone. In both cases
// ny = nbits/wnd + 1 and the returned window is nbits/ny + 1.
//
// NOTE(review): the working window is used as a divisor, so inputs that
// drive it to zero (for example window == 2 in the second case) panic
// with a division by zero. Presumably callers never pass such values;
// confirm.
func breakdown(nbits, window, ncpus int) (int, int, int) {
    narrowed := func(k int) int {
        return window - bits.Len(3*uint(k)/2)
    }

    cols, wnd := 1, 0
    if nbits <= window*ncpus {
        cols, wnd = 2, window-2
        for ncpus > cols*(nbits/wnd+1) {
            cols++
            wnd = narrowed(cols)
        }
        // The loop overshoots by one; step back and recompute.
        cols--
        wnd = narrowed(cols)
    } else {
        wnd = window - bits.Len(uint(ncpus)/4)
    }

    rows := nbits/wnd + 1
    return cols, rows, nbits/rows + 1
}

// pippenger_window_size maps npoints to a window size based on
// wbits = bits.Len(uint(npoints)):
//   wbits > 13 -> wbits - 4
//   wbits > 5  -> wbits - 3
//   otherwise  -> 2
func pippenger_window_size(npoints int) int {
    switch wbits := bits.Len(uint(npoints)); {
    case wbits > 13:
        return wbits - 4
    case wbits > 5:
        return wbits - 3
    default:
        return 2
    }
}
