package prover

import (
	"fmt"
	"time"

	"github.com/consensys/gnark/frontend"
	goprover "github.com/pflow-xyz/go-pflow/prover"
	"github.com/rs/zerolog/log"
)

type (
	// ArcnetWitnessFactory turns named witness values into assignments for
	// arcnet's domain-specific circuits (see its CreateAssignment method).
	// It carries no state, so the zero value is ready to use.
	ArcnetWitnessFactory struct{}
)

// CreateAssignment implements goprover.WitnessFactory.
//
// It builds the circuit assignment for circuitName from witness, a map from
// witness field name to its string encoding. Each value is decoded with
// goprover.ParseWitnessField. Fixed-size array fields (Merkle paths) are read
// from numbered keys such as "pathElement0", "pathIndex0", "pathElement1", …,
// visited index by index in that interleaved order.
//
// Parsing stops at the first field that fails. That error is returned
// unchanged together with a nil circuit. An unrecognised circuitName yields an
// "unknown circuit" error. The v3 homomorphic circuits are delegated to
// buildVoteCastHomomorphic8Assignment and buildTallyDecrypt8Assignment.
//
// NOTE(review): "tallyProof", "tallyProof_16" and "tallyProof_64" are
// registered by standardCircuits but have no case here, so they are rejected
// as unknown. Confirm that their witnesses are assigned elsewhere.
func (f *ArcnetWitnessFactory) CreateAssignment(circuitName string, witness map[string]string) (frontend.Circuit, error) {
	// firstErr records the first parse failure. Once it is set, field becomes a
	// no-op, so the error reported is the earliest one in parse order. This is
	// exactly what an early return after each parse would report.
	var firstErr error

	// field decodes a single named witness value.
	field := func(name string) frontend.Variable {
		if firstErr != nil {
			return nil
		}
		v, err := goprover.ParseWitnessField(witness, name)
		if err != nil {
			firstErr = err
			return nil
		}
		return v
	}

	// indexed pairs a witness key prefix with the array slots it populates.
	type indexed struct {
		prefix string
		dst    []frontend.Variable
	}

	// fill populates equally sized arrays from "<prefix><i>" keys. It
	// interleaves the arrays per index (a0, b0, a1, b1, …) so the first
	// failing key matches the historical parse order.
	fill := func(fields ...indexed) {
		if len(fields) == 0 {
			return
		}
		for i := range fields[0].dst {
			if firstErr != nil {
				return
			}
			for _, fd := range fields {
				fd.dst[i] = field(fmt.Sprintf("%s%d", fd.prefix, i))
			}
		}
	}

	var assignment frontend.Circuit
	switch circuitName {
	case "transfer":
		c := &TransferCircuit{}
		c.PreStateRoot = field("preStateRoot")
		c.PostStateRoot = field("postStateRoot")
		c.From = field("from")
		c.To = field("to")
		c.Amount = field("amount")
		c.BalanceFrom = field("balanceFrom")
		c.BalanceTo = field("balanceTo")
		fill(
			indexed{"pathElement", c.PathElements[:]},
			indexed{"pathIndex", c.PathIndices[:]},
		)
		assignment = c

	case "mint":
		c := &MintCircuit{}
		c.PreStateRoot = field("preStateRoot")
		c.PostStateRoot = field("postStateRoot")
		c.Caller = field("caller")
		c.To = field("to")
		c.Amount = field("amount")
		c.Minter = field("minter")
		c.BalanceTo = field("balanceTo")
		assignment = c

	case "burn":
		c := &BurnCircuit{}
		c.PreStateRoot = field("preStateRoot")
		c.PostStateRoot = field("postStateRoot")
		c.From = field("from")
		c.Amount = field("amount")
		c.BalanceFrom = field("balanceFrom")
		fill(
			indexed{"pathElement", c.PathElements[:]},
			indexed{"pathIndex", c.PathIndices[:]},
		)
		assignment = c

	case "approve":
		c := &ApproveCircuit{}
		c.PreStateRoot = field("preStateRoot")
		c.PostStateRoot = field("postStateRoot")
		c.Caller = field("caller")
		c.Spender = field("spender")
		c.Amount = field("amount")
		c.Owner = field("owner")
		assignment = c

	case "transferFrom":
		c := &TransferFromCircuit{}
		c.PreStateRoot = field("preStateRoot")
		c.PostStateRoot = field("postStateRoot")
		c.From = field("from")
		c.To = field("to")
		c.Caller = field("caller")
		c.Amount = field("amount")
		c.BalanceFrom = field("balanceFrom")
		c.AllowanceFrom = field("allowanceFrom")
		fill(
			indexed{"balancePath", c.BalancePath[:]},
			indexed{"balanceIndex", c.BalanceIndices[:]},
			indexed{"allowancePath", c.AllowancePath[:]},
			indexed{"allowanceIndex", c.AllowanceIdx[:]},
		)
		assignment = c

	case "vestClaim":
		c := &VestingClaimCircuit{}
		c.PreStateRoot = field("preStateRoot")
		c.PostStateRoot = field("postStateRoot")
		c.TokenID = field("tokenID")
		c.Caller = field("caller")
		c.ClaimAmount = field("claimAmount")
		c.VestedAmount = field("vestedAmount")
		c.Claimed = field("claimed")
		c.Owner = field("owner")
		fill(
			indexed{"schedulePath", c.SchedulePath[:]},
			indexed{"scheduleIndex", c.ScheduleIndices[:]},
			indexed{"ownerPath", c.OwnerPath[:]},
			indexed{"ownerIndex", c.OwnerIndices[:]},
		)
		assignment = c

	case "voteCast":
		c := &VoteCastCircuit{}
		c.PollID = field("pollId")
		c.VoterRegistryRoot = field("voterRegistryRoot")
		c.Nullifier = field("nullifier")
		c.VoteCommitment = field("voteCommitment")
		c.MaxChoices = field("maxChoices")
		c.VoterSecret = field("voterSecret")
		c.VoteChoice = field("voteChoice")
		c.VoterWeight = field("voterWeight")
		fill(
			indexed{"pathElement", c.PathElements[:]},
			indexed{"pathIndex", c.PathIndices[:]},
		)
		assignment = c

	case "voteCastHomomorphic_8":
		return buildVoteCastHomomorphic8Assignment(witness)

	case "tallyDecrypt_8":
		return buildTallyDecrypt8Assignment(witness)

	default:
		return nil, fmt.Errorf("unknown circuit: %s", circuitName)
	}

	if firstErr != nil {
		return nil, firstErr
	}
	return assignment, nil
}

// NewArcnetService builds a prover service wired with arcnet's circuits and an
// ArcnetWitnessFactory.
//
// When keyDir is empty, circuits are registered with RegisterStandardCircuits
// and the returned *KeyStore is nil. Otherwise a KeyStore is opened at keyDir,
// and the definitions from standardCircuits are registered through it, so keys
// are persisted to disk for fast restarts.
//
// On any error both the service and the keystore results are nil.
func NewArcnetService(keyDir string) (*Service, *KeyStore, error) {
	prv := NewProver()

	log.Info().Msg("Registering standard circuits...")
	began := time.Now()

	var store *KeyStore
	if keyDir == "" {
		if err := RegisterStandardCircuits(prv); err != nil {
			return nil, nil, fmt.Errorf("failed to register circuits: %w", err)
		}
	} else {
		var err error
		if store, err = NewKeyStore(keyDir); err != nil {
			return nil, nil, fmt.Errorf("failed to create keystore: %w", err)
		}
		if err = RegisterWithKeyStore(prv, store, standardCircuits()); err != nil {
			return nil, nil, fmt.Errorf("failed to register circuits: %w", err)
		}
	}

	// "cached" reports whether the keystore-backed path was taken.
	log.Info().
		Dur("elapsed", time.Since(began)).
		Int("circuits", len(prv.ListCircuits())).
		Bool("cached", store != nil).
		Msg("Circuits registered")

	factory := &ArcnetWitnessFactory{}
	return goprover.NewService(prv, factory), store, nil
}

// standardCircuits returns the uncompiled circuit definitions arcnet registers,
// keyed by circuit name. Every call allocates fresh zero-value circuit structs.
//
// "tallyProof" and "tallyProof_16" both map to TallyProofCircuit16, each with
// its own instance. Presumably the unsuffixed name is kept as an alias for
// existing callers; confirm before dropping it.
func standardCircuits() map[string]frontend.Circuit {
	defs := make(map[string]frontend.Circuit, 12)

	defs["transfer"] = &TransferCircuit{}
	defs["transferFrom"] = &TransferFromCircuit{}
	defs["mint"] = &MintCircuit{}
	defs["burn"] = &BurnCircuit{}
	defs["approve"] = &ApproveCircuit{}
	defs["vestClaim"] = &VestingClaimCircuit{}
	defs["voteCast"] = &VoteCastCircuit{}

	defs["tallyProof"] = &TallyProofCircuit16{}
	defs["tallyProof_16"] = &TallyProofCircuit16{}
	defs["tallyProof_64"] = &TallyProofCircuit64{}

	// v3 (homomorphic-tally) circuits. They are listed here so the keystore
	// path persists their cs/pk/vk and the in-browser prover can fetch them via
	// /api/keys/{name}.* — without persisted keys the WASM worker has nothing
	// to prove against.
	defs["voteCastHomomorphic_8"] = &VoteCastHomomorphicCircuit_8{}
	defs["tallyDecrypt_8"] = &TallyDecryptCircuit_8{}

	return defs
}
