package prover

import (
	"bytes"
	"fmt"
	"math/big"

	"github.com/consensys/gnark-crypto/ecc"
	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
	"github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc"
	"github.com/consensys/gnark/backend/groth16"
	"github.com/consensys/gnark/backend/witness"
	"github.com/consensys/gnark/frontend"
)

// VerifyProofResult verifies a ProofResult against a compiled circuit's verifying key.
// The proof must have been generated by the same circuit.
//
// Steps:
//   - look up circuitName on p;
//   - rebuild a gnark BN254 proof from proof via reconstructProof;
//   - build a public-only witness from proof.PublicInputs via buildPublicWitness;
//   - run groth16.Verify.
//
// Each failure is returned wrapped with the stage that failed.
//
// NOTE(review): reconstructProof does not implement reconstruction and always
// returns an error, so as written this function cannot succeed. For server-side
// checks of gnark-serialized proofs, see VerifyVoteCastProofBytes.
func VerifyProofResult(p *Prover, circuitName string, proof *ProofResult) error {
	// Guard nil inputs up front: reconstructProof reads proof.RawProof and
	// would otherwise panic on a nil ProofResult.
	if p == nil {
		return fmt.Errorf("prover is nil")
	}
	if proof == nil {
		return fmt.Errorf("proof result is nil")
	}

	cc, ok := p.GetCircuit(circuitName)
	if !ok {
		return fmt.Errorf("circuit %q not registered", circuitName)
	}

	// Reconstruct gnark proof from ProofResult.
	gnarkProof := groth16.NewProof(ecc.BN254)
	if err := reconstructProof(gnarkProof, proof); err != nil {
		return fmt.Errorf("invalid proof format: %w", err)
	}

	// Build public witness from PublicInputs.
	pubWitness, err := buildPublicWitness(circuitName, proof.PublicInputs)
	if err != nil {
		return fmt.Errorf("invalid public inputs: %w", err)
	}

	if err := groth16.Verify(gnarkProof, cc.VerifyingKey, pubWitness); err != nil {
		return fmt.Errorf("proof verification failed: %w", err)
	}

	return nil
}

// reconstructProof is meant to fill a gnark Groth16 proof from result.RawProof.
// RawProof is laid out as [A.X, A.Y, B.X[0], B.X[1], B.Y[0], B.Y[1], C.X, C.Y],
// the point format expected by the Solidity verifier.
//
// Reconstruction is not implemented, because gnark's groth16.Proof interface
// exposes no setters for the A/B/C points. After validating the input shape,
// this function always returns an error and leaves proof untouched.
//
// A workable alternative is to persist gnark's own proof serialization
// (WriteTo) alongside the ProofResult. It can then be read back with ReadFrom,
// as VerifyVoteCastProofBytes does.
func reconstructProof(proof groth16.Proof, result *ProofResult) error {
	if result == nil {
		return fmt.Errorf("proof result is nil")
	}
	if len(result.RawProof) < 8 {
		return fmt.Errorf("raw_proof must have at least 8 elements, got %d", len(result.RawProof))
	}

	return fmt.Errorf("direct proof reconstruction not supported; use VerifyVoteCastWitness instead")
}

// VerifyVoteCastWitness checks a voteCast witness by building a circuit
// assignment from it and handing that to p.Verify. Prover.Verify proves and
// verifies in one step. This is the server-side verification for off-chain polls.
//
// witnessData is keyed by whatever field names ArcnetWitnessFactory.CreateAssignment
// expects for "voteCast" (TODO: confirm the key set against the factory). Because
// a proof is generated here, the map presumably has to carry the private inputs
// too (e.g. voterSecret, voteChoice). VerifyVoteCastProofBytes avoids that.
//
// A witness that cannot be turned into an assignment is reported as
// "invalid witness: …". Any other error comes straight from p.Verify.
func VerifyVoteCastWitness(p *Prover, witnessData map[string]string) error {
	const circuit = "voteCast"

	assignment, buildErr := (&ArcnetWitnessFactory{}).CreateAssignment(circuit, witnessData)
	if buildErr != nil {
		return fmt.Errorf("invalid witness: %w", buildErr)
	}
	return p.Verify(circuit, assignment)
}

// VerifyVoteCastProofBytes verifies a voteCast Groth16 proof given in gnark's
// binary encoding.
//
// proofBytes is decoded with groth16.Proof.ReadFrom. pubWitnessBytes is decoded
// into a BN254 scalar-field witness with witness.Witness.ReadFrom. Both are then
// checked against the registered voteCast verifying key. Only public data is
// involved: this path never sees private inputs (voterSecret, voteChoice).
//
// Returns nil when the proof verifies. Otherwise the error names the failing
// stage: registration, proof decoding, witness creation or decoding, or
// verification.
func VerifyVoteCastProofBytes(p *Prover, proofBytes, pubWitnessBytes []byte) error {
	circuit, found := p.GetCircuit("voteCast")
	if !found {
		return fmt.Errorf("voteCast circuit not registered")
	}

	gProof := groth16.NewProof(ecc.BN254)
	proofReader := bytes.NewReader(proofBytes)
	if _, readErr := gProof.ReadFrom(proofReader); readErr != nil {
		return fmt.Errorf("invalid proof bytes: %w", readErr)
	}

	publicWitness, newErr := witness.New(ecc.BN254.ScalarField())
	if newErr != nil {
		return fmt.Errorf("witness creation failed: %w", newErr)
	}
	witnessReader := bytes.NewReader(pubWitnessBytes)
	if _, readErr := publicWitness.ReadFrom(witnessReader); readErr != nil {
		return fmt.Errorf("invalid public witness bytes: %w", readErr)
	}

	verifyErr := groth16.Verify(gProof, circuit.VerifyingKey, publicWitness)
	if verifyErr == nil {
		return nil
	}
	return fmt.Errorf("proof verification failed: %w", verifyErr)
}

// ValidateVoteCastPublicInputs checks that a voteCast proof's public inputs
// match the expected poll parameters.
//
// publicInputs follows the circuit's public-field order: PollID,
// VoterRegistryRoot, Nullifier, VoteCommitment, MaxChoices. At least five
// entries are required, but only the first two (poll ID and registry root)
// are compared here. The comparison uses bigIntEqual, so "0x"-hex and decimal
// spellings of the same number match.
//
// Returns nil on a match, or an error describing the first mismatch.
func ValidateVoteCastPublicInputs(publicInputs []string, expectedPollID, expectedRegistryRoot string) error {
	if count := len(publicInputs); count < 5 {
		return fmt.Errorf("voteCast requires 5 public inputs (pollId, registryRoot, nullifier, voteCommitment, maxChoices), got %d", count)
	}

	pollID, registryRoot := publicInputs[0], publicInputs[1]

	switch {
	case !bigIntEqual(pollID, expectedPollID):
		return fmt.Errorf("proof pollId mismatch: proof=%s expected=%s", pollID, expectedPollID)
	case !bigIntEqual(registryRoot, expectedRegistryRoot):
		return fmt.Errorf("proof registryRoot mismatch: proof=%s expected=%s", registryRoot, expectedRegistryRoot)
	default:
		return nil
	}
}

// buildPublicWitness constructs a gnark public witness from hex-encoded public inputs.
//
// Only "voteCast" is supported. Its inputs are read in circuit order:
// [PollID, VoterRegistryRoot, Nullifier, VoteCommitment, MaxChoices]. Entries
// past the fifth are ignored. Each entry must be a "0x"-prefixed hexadecimal or
// a decimal integer.
//
// Malformed entries are rejected with an error rather than silently becoming
// zero. Otherwise the proof would be checked against public values the caller
// never supplied. Any other circuit name yields an "unsupported circuit" error.
func buildPublicWitness(circuitName string, publicInputs []string) (witness.Witness, error) {
	switch circuitName {
	case "voteCast":
		if len(publicInputs) < 5 {
			return nil, fmt.Errorf("voteCast requires 5 public inputs, got %d", len(publicInputs))
		}

		// parse accepts "0x"-prefixed hex, otherwise decimal, and reports
		// failure instead of defaulting to zero.
		parse := func(s string) (*big.Int, bool) {
			if len(s) > 2 && s[:2] == "0x" {
				return new(big.Int).SetString(s[2:], 16)
			}
			return new(big.Int).SetString(s, 10)
		}

		vals := make([]*big.Int, 5)
		for i := range vals {
			v, ok := parse(publicInputs[i])
			if !ok {
				return nil, fmt.Errorf("voteCast public input %d (%q) is not a valid hex or decimal integer", i, publicInputs[i])
			}
			vals[i] = v
		}

		// Only the public fields are set; the witness is built with
		// frontend.PublicOnly(), so private fields are not needed.
		assignment := &VoteCastCircuit{
			PollID:            vals[0],
			VoterRegistryRoot: vals[1],
			Nullifier:         vals[2],
			VoteCommitment:    vals[3],
			MaxChoices:        vals[4],
		}
		w, err := frontend.NewWitness(assignment, ecc.BN254.ScalarField(), frontend.PublicOnly())
		if err != nil {
			return nil, fmt.Errorf("public witness creation: %w", err)
		}
		return w, nil
	default:
		return nil, fmt.Errorf("unsupported circuit for verification: %s", circuitName)
	}
}

// ValidateVoteReveal checks that mimcHash(voterSecret, voteChoice) equals
// storedCommitment. It is used during the reveal phase after the poll closes.
//
// voterSecret and storedCommitment are read with parseBigIntOrZero ("0x"-hex,
// otherwise decimal). voteChoice is widened to a big.Int. The hash is computed
// with MiMCHashBigInt.
//
// Returns nil when the recomputed commitment matches. Otherwise it returns a
// "commitment mismatch" error that names the choice but never the secret.
func ValidateVoteReveal(voterSecret string, voteChoice int, storedCommitment string) error {
	expected := parseBigIntOrZero(storedCommitment).(*big.Int)

	recomputed := MiMCHashBigInt(
		parseBigIntOrZero(voterSecret).(*big.Int),
		big.NewInt(int64(voteChoice)),
	)

	if recomputed.Cmp(expected) == 0 {
		return nil
	}
	return fmt.Errorf("commitment mismatch: mimcHash(secret, %d) != stored commitment", voteChoice)
}

// bigIntEqual compares two integers given as "0x"-prefixed hex or decimal
// strings. It reports whether they denote the same number, so "0x0a" and "10"
// are equal.
//
// If either string does not parse completely, the function falls back to exact
// string equality. Malformed values therefore never compare equal just because
// both collapsed to zero or to a partially parsed prefix (e.g. "garbage" vs "",
// or "12x" vs "12"). This matters because the result gates proof
// public-input validation.
func bigIntEqual(a, b string) bool {
	// parse mirrors parseBigIntOrZero's format rules but reports failure:
	// big.Int.SetString returns ok=false and leaves its receiver undefined on
	// bad input.
	parse := func(s string) (*big.Int, bool) {
		if len(s) > 2 && s[:2] == "0x" {
			return new(big.Int).SetString(s[2:], 16)
		}
		return new(big.Int).SetString(s, 10)
	}

	ai, okA := parse(a)
	bi, okB := parse(b)
	if !okA || !okB {
		return a == b
	}
	return ai.Cmp(bi) == 0
}

// MiMCHashBigInt computes MiMC hash of two big.Int values.
// Matches the gnark circuit's mimcHash exactly.
func MiMCHashBigInt(a, b *big.Int) *big.Int {
	var fa, fb fr.Element
	fa.SetBigInt(a)
	fb.SetBigInt(b)

	h := mimc.NewMiMC()
	ab := fa.Bytes()
	bb := fb.Bytes()
	h.Write(ab[:])
	h.Write(bb[:])
	sum := h.Sum(nil)

	result := new(big.Int).SetBytes(sum)
	return result
}

// parseBigIntOrZero parses s as a "0x"-prefixed hexadecimal integer or,
// otherwise, a decimal integer. It returns the value as a *big.Int boxed in an
// interface{}: callers type-assert it to *big.Int, and it is also assigned
// directly to circuit fields.
//
// On any parse failure the result is a *big.Int equal to zero, and it is never
// a nil interface. big.Int.SetString leaves its receiver undefined on failure
// (in practice a partially parsed value, e.g. "12x" yields 12), so the receiver
// is never returned on error. Note that only a lowercase "0x" prefix selects
// hex.
func parseBigIntOrZero(s string) interface{} {
	base, digits := 10, s
	if len(s) > 2 && s[:2] == "0x" {
		base, digits = 16, s[2:]
	}
	if n, ok := new(big.Int).SetString(digits, base); ok {
		return n
	}
	return new(big.Int)
}
