//go:build linux

package netlink

import (
	"log/slog"

	nfq "github.com/florianl/go-nfqueue/v2"
	"github.com/google/gopacket/layers"
	"github.com/sunbk201/ua3f/internal/bpf/tc"
	"github.com/sunbk201/ua3f/internal/common"
	"github.com/sunbk201/ua3f/internal/config"
	"github.com/sunbk201/ua3f/internal/netfilter"
	"github.com/sunbk201/ua3f/internal/server/base"
	"sigs.k8s.io/knftables"
)

// Server applies L3/L4 packet rewrites (the TTL, TCPTS, TCPWIN, IPID and
// BLOCKQUIC options of the L3 rewrite configuration), either through a BPF TC
// program or through firewall rules combined with an NFQUEUE handler.
type Server struct {
	// Firewall is populated in New with the UA3F_HELPER table and the
	// nft/ipt setup and cleanup callbacks of this server.
	netfilter.Firewall
	// cfg points at the L3Rewrite section of mainCfg.
	cfg       *config.L3RewriteConfig
	// mainCfg is the full configuration; Start passes it to Firewall.Setup.
	mainCfg   *config.Config
	// nfqServer is bound to netfilter.HELPER_QUEUE and dispatches packets to
	// handlePacket.
	nfqServer *base.NfqueueServer
	// tc is assigned by Start only when BPFOffload is enabled; it stays nil
	// otherwise.
	tc        *tc.TC
}

// New builds a Server from cfg without starting anything. The NFQUEUE server
// is bound to netfilter.HELPER_QUEUE and wired to handlePacket, and the
// embedded Firewall is given the UA3F_HELPER inet table together with the
// server's nft/ipt setup and cleanup callbacks.
func New(cfg *config.Config) *Server {
	queue := &base.NfqueueServer{QueueNum: netfilter.HELPER_QUEUE}
	srv := &Server{
		cfg:       &cfg.L3Rewrite,
		mainCfg:   cfg,
		nfqServer: queue,
	}
	// The handler is a method value, so it can only be bound once srv exists.
	queue.HandlePacket = srv.handlePacket

	srv.Firewall.Nftable = &knftables.Table{
		Name:   "UA3F_HELPER",
		Family: knftables.InetFamily,
	}
	srv.Firewall.NftSetup = srv.nftSetup
	srv.Firewall.NftCleanup = srv.nftCleanup
	srv.Firewall.IptSetup = srv.iptSetup
	srv.Firewall.IptCleanup = srv.iptCleanup
	return srv
}

// Start brings up packet rewriting according to the L3 rewrite configuration.
//
// It is a no-op when none of TTL, TCPTS, TCPWIN, IPID or BLOCKQUIC is enabled.
// With BPFOffload set, only the BPF TC program is initialised and the
// firewall/NFQUEUE path is skipped entirely. Otherwise the firewall rules are
// installed, and the NFQUEUE server is started only when TCPTS, TCPWIN or IPID
// is enabled.
func (s *Server) Start() (err error) {
	if !s.cfg.TTL && !s.cfg.TCPTS && !s.cfg.TCPWIN && !s.cfg.IPID && !s.cfg.BLOCKQUIC {
		return nil
	}

	if s.cfg.BPFOffload {
		s.tc, err = tc.NewTC(s.cfg)
		if err != nil {
			slog.Error("initialize BPF TC failed, please try disable BPF offload", slog.Any("error", err))
			return err
		}
		slog.Info("BPF TC initialized successfully")
		return nil
	}

	if err = s.Firewall.Setup(s.mainCfg); err != nil {
		slog.Error("s.Firewall.Setup", slog.Any("error", err))
		return err
	}
	slog.Info(
		"Packet modification configuration",
		slog.Bool("ttl", s.cfg.TTL),
		slog.Bool("tcpts", s.cfg.TCPTS),
		slog.Bool("ipid", s.cfg.IPID),
		slog.Bool("tcp_init_window", s.cfg.TCPWIN),
		slog.Bool("block_quic", s.cfg.BLOCKQUIC),
	)

	// Only these three options are applied by handlePacket, so the queue
	// server is not started for the remaining ones.
	if !(s.cfg.TCPTS || s.cfg.TCPWIN || s.cfg.IPID) {
		return nil
	}
	return s.nfqServer.Start()
}

// Close releases the BPF TC handle if Start created one, runs the firewall
// cleanup and closes the NFQUEUE server.
//
// The returned error is the one reported by the firewall cleanup; the NFQUEUE
// server is closed before it is returned, regardless of its value.
func (s *Server) Close() error {
	// s.tc is only assigned by Start when BPF offload is enabled, so guard
	// against a nil handle (e.g. Close/Restart on the firewall-only path).
	if s.tc != nil {
		s.tc.Close()
	}
	err := s.Firewall.Cleanup()
	s.nfqServer.Close()
	return err
}

// Restart closes the current server and then builds and starts a fresh one
// from cfg. On success the new server is returned. If closing the old server
// or starting the new one fails, that error is returned together with a nil
// server.
func (s *Server) Restart(cfg *config.Config) (*Server, error) {
	closeErr := s.Close()
	if closeErr != nil {
		return nil, closeErr
	}

	replacement := New(cfg)
	startErr := replacement.Start()
	if startErr != nil {
		return nil, startErr
	}
	return replacement, nil
}

// handlePacket applies the enabled rewrites to one NFQUEUE packet and issues
// its verdict. The packet is always accepted: with the re-serialized payload
// when a rewrite changed it, or unchanged when nothing was modified,
// serialization failed, or the altered verdict could not be delivered.
func (s *Server) handlePacket(packet *common.Packet) {
	queue := s.nfqServer.Nf

	// acceptUnchanged lets the original packet through; a failure here is
	// deliberately ignored because there is no further fallback.
	acceptUnchanged := func() {
		_ = queue.SetVerdict(*packet.A.PacketID, nfq.NfAccept)
	}

	changed := false
	if tcp := packet.TCP; tcp != nil {
		if s.cfg.TCPTS && s.clearTCPTimestamp(tcp) {
			changed = true
		}
		if s.cfg.TCPWIN && s.setInitialTCPWindow(tcp) {
			changed = true
		}
	}
	if s.cfg.IPID && s.zeroIPID(packet) {
		changed = true
	}

	if !changed {
		acceptUnchanged()
		return
	}

	payload, err := packet.Serialize()
	if err != nil {
		slog.Error("packet.Serialize", slog.Any("error", err))
		acceptUnchanged()
		return
	}
	if err := queue.SetVerdictWithOption(*packet.A.PacketID, nfq.NfAccept, nfq.WithAlteredPacket(payload)); err != nil {
		slog.Error("nf.SetVerdictWithOption", slog.Any("error", err))
		acceptUnchanged()
	}
}

// clearTCPTimestamp strips the timestamp option from tcp's option list.
// When a timestamp option is present, the list is rebuilt without it and
// without any NOP / End-of-List entries (presumably pure padding for the old
// layout; re-padding is left to serialization — TODO confirm).
// It reports whether a timestamp option was found; if none was, the option
// list is left untouched.
func (s *Server) clearTCPTimestamp(tcp *layers.TCP) bool {
	if len(tcp.Options) == 0 {
		return false
	}

	kept := make([]layers.TCPOption, 0, len(tcp.Options))
	found := false
	for i := range tcp.Options {
		switch tcp.Options[i].OptionType {
		case layers.TCPOptionKindTimestamps:
			found = true
		case layers.TCPOptionKindNop, layers.TCPOptionKindEndList:
			// Dropped from the rebuilt list.
		default:
			kept = append(kept, tcp.Options[i])
		}
	}

	if !found {
		return false
	}
	tcp.Options = kept
	return true
}

// setInitialTCPWindow forces the advertised window of an initial SYN
// (SYN set, ACK clear) to 65535. It reports whether the segment was changed;
// other segments and SYNs already advertising 65535 are left alone.
func (s *Server) setInitialTCPWindow(tcp *layers.TCP) bool {
	const initialWindow = uint16(65535)

	isInitialSYN := tcp.SYN && !tcp.ACK
	if !isInitialSYN || tcp.Window == initialWindow {
		return false
	}
	tcp.Window = initialWindow
	return true
}

// zeroIPID sets the IP identification field to zero for IPv4 packets.
// It returns true if the packet was modified. IPv6 packets, packets whose
// network layer is missing or is not an IPv4 layer, and packets whose ID is
// already zero are left untouched and yield false.
func (s *Server) zeroIPID(packet *common.Packet) bool {
	if packet.IsIPv6 {
		return false
	}
	// Use the comma-ok form so a nil or unexpected network layer results in
	// "not modified" instead of a panic in the packet-handling path.
	ip4, ok := packet.NetworkLayer.(*layers.IPv4)
	if !ok || ip4 == nil {
		return false
	}
	if ip4.Id == 0 {
		return false
	}
	ip4.Id = 0
	return true
}
