package base

import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"hash/fnv"
	"log/slog"
	"runtime"
	"strings"
	"sync"

	nfq "github.com/florianl/go-nfqueue/v2"
	"github.com/mdlayher/netlink"
	"github.com/sunbk201/ua3f/internal/common"
)

// NfqHandler is the callback a worker goroutine invokes for every packet it
// successfully builds from a queued nfqueue attribute.
// NOTE(review): the worker issues no verdict after calling the handler, so the
// handler is presumably responsible for setting one — confirm with callers.
type NfqHandler func(a *common.Packet)

// NfqueueServer receives packets from a netfilter queue and fans them out to a
// fixed set of worker goroutines, each fed by its own buffered channel.
type NfqueueServer struct {
	// HandlePacket is called by a worker for each parsed packet. Start fails
	// when it is nil.
	HandlePacket  NfqHandler
	// Nf is the nfqueue handle opened by Start.
	Nf            *nfq.Nfqueue
	// cancel cancels the context handed to the nfqueue callback registration.
	cancel        context.CancelFunc
	// attrChans holds one channel per worker; the receive callback picks the
	// channel via computeWorkerIndex.
	attrChans     []chan *nfq.Attribute
	// wg tracks the running worker goroutines.
	wg            sync.WaitGroup
	// NumWorkers is the number of worker goroutines. When <= 0, Start sets it
	// to runtime.NumCPU(), but never below 2.
	NumWorkers    int
	// WorkerChanLen is the buffer size of each worker channel. When <= 0,
	// Start sets it to 2000.
	WorkerChanLen int
	// MaxQueueLen is passed to nfq.Config.MaxQueueLen. When 0, Start sets it
	// to 4000.
	MaxQueueLen   uint32
	// MaxPacketLen is passed to nfq.Config.MaxPacketLen. When 0, Start sets it
	// to 1600.
	MaxPacketLen  uint32
	// QueueNum is the netfilter queue number to open. Start fails when it is 0.
	QueueNum      uint16
}

// Start validates the configuration, opens the netfilter queue, launches the
// worker goroutines and registers the receive and error callbacks.
//
// Unset tuning fields are replaced with defaults: MaxQueueLen 4000,
// MaxPacketLen 1600, NumWorkers runtime.NumCPU() (at least 2) and
// WorkerChanLen 2000.
//
// It returns an error when QueueNum is 0, HandlePacket is nil, the queue cannot
// be opened, or the callbacks cannot be registered. When registration fails,
// everything Start created is torn down again before returning.
func (s *NfqueueServer) Start() error {
	if s.QueueNum == 0 {
		return errors.New("NfqueueServer.QueueNum is 0")
	}
	if s.HandlePacket == nil {
		return errors.New("NfqueueServer.HandlePacket is nil")
	}
	// The length fields are unsigned, so zero is the only "unset" value.
	if s.MaxQueueLen == 0 {
		s.MaxQueueLen = 4000
	}
	if s.MaxPacketLen == 0 {
		s.MaxPacketLen = 1600
	}
	if s.NumWorkers <= 0 {
		s.NumWorkers = runtime.NumCPU()
		if s.NumWorkers < 2 {
			s.NumWorkers = 2
		}
	}
	if s.WorkerChanLen <= 0 {
		s.WorkerChanLen = 2000
	}
	config := nfq.Config{
		NfQueue:      s.QueueNum,
		MaxQueueLen:  s.MaxQueueLen,
		MaxPacketLen: s.MaxPacketLen,
		Copymode:     nfq.NfQnlCopyPacket,
		Flags:        nfq.NfQaCfgFlagConntrack,
	}

	nf, err := nfq.Open(&config)
	if err != nil {
		return fmt.Errorf("nfq.Open: %w", err)
	}
	s.Nf = nf

	// Ignore ENOBUFS to prevent queue drop logs
	// if err := nf.SetOption(netlink.NoENOBUFS, true); err != nil {
	//	return fmt.Errorf("nf.SetOption: %w", err)
	// }

	// A failure to enlarge the socket buffer is logged but not fatal.
	if bufErr := nf.Con.SetReadBuffer(1024 * 1024 * 2); bufErr != nil {
		slog.Error("nf.Con.SetReadBuffer", slog.Any("error", bufErr))
	}

	ctx, cancel := context.WithCancel(context.Background())
	s.cancel = cancel

	// Initialize worker channels and start worker goroutines
	s.attrChans = make([]chan *nfq.Attribute, s.NumWorkers)
	for i := 0; i < s.NumWorkers; i++ {
		s.attrChans[i] = make(chan *nfq.Attribute, s.WorkerChanLen)
		s.wg.Add(1)
		go s.worker(i, s.attrChans[i])
	}

	// Register callback function
	err = nf.RegisterWithErrorFunc(ctx,
		func(a nfq.Attribute) int {
			select {
			case s.attrChans[s.computeWorkerIndex(&a)] <- &a:
			default:
				// If worker channel is full, accept the packet to avoid blocking
				slog.Warn("Worker channel full, accepting packet without processing")
				if a.PacketID != nil {
					_ = nf.SetVerdict(*a.PacketID, nfq.NfAccept)
				}
			}
			return 0
		},
		func(e error) int {
			if strings.Contains(e.Error(), "no buffer space available") {
				slog.Warn("No buffer space available, consider increasing the read buffer size to prevent packet drops")
				// Use a callback-local error: this function runs outside
				// Start's goroutine and must not write to Start's err.
				if bufErr := nf.Con.SetReadBuffer(1024 * 1024 * 5); bufErr != nil {
					slog.Error("nf.Con.SetReadBuffer", slog.Any("error", bufErr))
				}
			} else if errors.Is(ctx.Err(), context.Canceled) {
				slog.Info("Nfqueue context canceled, stopping nfqueue handler")
			} else {
				slog.Error("Error in nfqueue handler", slog.Any("error", e))
			}
			return 0
		},
	)
	if err != nil {
		// Registration failed: undo everything created above so that no
		// goroutine, context or socket outlives the failed Start.
		cancel()
		for _, ch := range s.attrChans {
			close(ch)
		}
		s.wg.Wait()
		_ = nf.Close()
		// Clear the fields so that a later Close on this server is a no-op
		// instead of closing the channels a second time.
		s.attrChans = nil
		s.cancel = nil
		s.Nf = nil
		return fmt.Errorf("nf.RegisterWithErrorFunc: %w", err)
	}
	return nil
}

// Close shuts the server down: it cancels the callback context, closes every
// worker channel, waits for the workers to drain and exit, and finally closes
// the nfqueue handle.
//
// Calling Close again after it has returned is safe; the worker channels are
// forgotten once closed, so they are never closed twice. Close is not meant to
// be called concurrently from several goroutines.
//
// NOTE(review): the channels are closed right after cancel(). If the receive
// callback can still be delivering a packet at that moment, its send would hit
// a closed channel — confirm against the go-nfqueue shutdown semantics.
func (s *NfqueueServer) Close() {
	if s.cancel != nil {
		s.cancel()
	}

	for _, ch := range s.attrChans {
		if ch != nil {
			close(ch)
		}
	}

	s.wg.Wait()

	if s.Nf != nil {
		_ = s.Nf.Close()
	}

	// Drop the closed channels only after the handle is closed, so a repeated
	// Close finds nothing left to close.
	s.attrChans = nil
}

// worker consumes attributes from aChan until the channel is closed, turning
// each one into a common.Packet and passing it to s.HandlePacket.
//
// Attributes that fail attributeSanityCheck or cannot be parsed are accepted
// (when they carry a packet ID) and skipped. The worker keeps running after a
// bad attribute, so later packets routed to it are still handled. workerID is
// used for logging only.
func (s *NfqueueServer) worker(workerID int, aChan <-chan *nfq.Attribute) {
	defer s.wg.Done()

	for a := range aChan {
		if ok := attributeSanityCheck(a); !ok {
			if a.PacketID != nil {
				_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
			}
			slog.Warn("Invalid nfq.Attribute received", slog.Int("workerID", workerID))
			// Skip only this attribute; returning here would stop the worker
			// for good and strand every packet hashed to its channel.
			continue
		}
		packet, err := common.NewPacket(a)
		if err != nil {
			slog.Error("NewPacket", slog.Int("workerID", workerID), slog.Any("error", err))
			if a.PacketID != nil {
				_ = s.Nf.SetVerdict(*a.PacketID, nfq.NfAccept)
			}
			continue
		}
		slog.Debug("Processing packet", slog.Int("workerID", workerID), slog.String("srcAddr", packet.SrcAddr), slog.String("dstAddr", packet.DstAddr))
		s.HandlePacket(packet)
	}
}

// computeWorkerIndex maps an attribute to a worker index in [0, NumWorkers).
//
// The conntrack ID is used when the attribute carries conntrack data;
// otherwise a hash of the packet payload is used. Attributes with neither
// conntrack data nor a payload map to worker 0, where the worker's sanity
// check rejects them. It returns 0 when NumWorkers is not positive.
func (s *NfqueueServer) computeWorkerIndex(a *nfq.Attribute) int {
	if s.NumWorkers <= 0 {
		// Avoid a division by zero when the server is not configured.
		return 0
	}

	var flowID uint32
	switch {
	case a.Ct != nil:
		flowID = ctIDFromCtBytes(*a.Ct)
	case a.Payload != nil && len(*a.Payload) > 0:
		// Compute flow hash to determine which worker should handle this packet
		flowID = computeFlowHash(*a.Payload)
	}
	return int(flowID % uint32(s.NumWorkers))
}

// computeFlowHash returns a 32-bit FNV-1a hash over the source and destination
// addresses of an IP packet and, for TCP, the source and destination ports, so
// that packets of one flow direction map to the same worker.
//
// pktData is the raw packet starting at the IP header. The function returns 0
// for empty input and for packets shorter than the fixed IPv4 (20 bytes) or
// IPv6 (40 bytes) header. For other IP versions nothing is hashed and the
// hash of empty input is returned.
func computeFlowHash(pktData []byte) uint32 {
	if len(pktData) == 0 {
		// Nothing to inspect; reading the version nibble would panic.
		return 0
	}

	h := fnv.New32a()

	switch pktData[0] >> 4 {
	case 4:
		// IPv4: IP header is at least 20 bytes
		if len(pktData) < 20 {
			return 0
		}

		// Source IP (bytes 12-15) and Dest IP (bytes 16-19)
		_, _ = h.Write(pktData[12:20])

		// Check if it's TCP (protocol 6)
		if pktData[9] == 6 {
			// IHL (IP Header Length) is in the lower 4 bits of byte 0, counted
			// in 32-bit words. A value below 20 bytes is malformed and would
			// point into the IP header itself, so the ports are skipped then.
			ihl := int(pktData[0]&0x0F) * 4
			if ihl >= 20 && len(pktData) >= ihl+4 {
				// TCP source port and dest port (first 4 bytes of TCP header)
				_, _ = h.Write(pktData[ihl : ihl+4])
			}
		}

	case 6:
		// IPv6: IP header is at least 40 bytes
		if len(pktData) < 40 {
			return 0
		}

		// Source IP (bytes 8-23) and Dest IP (bytes 24-39)
		_, _ = h.Write(pktData[8:40])

		// Check if it's TCP (next header 6)
		if pktData[6] == 6 && len(pktData) >= 44 {
			// TCP source port and dest port (first 4 bytes of TCP header at offset 40)
			_, _ = h.Write(pktData[40:44])
		}
	}

	return h.Sum32()
}

// ctIDFromCtBytes extracts the conntrack ID from netlink-encoded conntrack
// attributes.
//
// It returns 0 when ct cannot be unmarshalled, when no CTA_ID attribute is
// present, or when the CTA_ID attribute holds fewer than 4 bytes.
func ctIDFromCtBytes(ct []byte) uint32 {
	// ctaID is the attribute type of the conntrack ID (CTA_ID).
	const ctaID = 12

	ctAttrs, err := netlink.UnmarshalAttributes(ct)
	if err != nil {
		return 0
	}
	for _, attr := range ctAttrs {
		if attr.Type != ctaID {
			continue
		}
		if len(attr.Data) < 4 {
			// Truncated attribute: decoding a uint32 from it would panic.
			return 0
		}
		return binary.BigEndian.Uint32(attr.Data)
	}
	return 0
}

// attributeSanityCheck reports whether a carries a packet ID and a payload of
// at least 20 bytes.
func attributeSanityCheck(a *nfq.Attribute) (ok bool) {
	hasID := a.PacketID != nil
	hasPayload := a.Payload != nil && len(*a.Payload) >= 20
	return hasID && hasPayload
}
