package cmd

import (
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"strings"
	"syscall"

	"github.com/spf13/cobra"
	"github.com/spf13/viper"
	"github.com/sunbk201/ua3f/internal/api"
	"github.com/sunbk201/ua3f/internal/config"
	"github.com/sunbk201/ua3f/internal/daemon"
	"github.com/sunbk201/ua3f/internal/log"
	"github.com/sunbk201/ua3f/internal/server"
	"github.com/sunbk201/ua3f/internal/server/desync"
	"github.com/sunbk201/ua3f/internal/server/netlink"
)

// AppVersion is the version string printed by -v/--version and handed to the
// log header and the API server. It defaults to "Development".
// NOTE(review): presumably overridden at build time (e.g. -ldflags "-X ...") — confirm.
var AppVersion = "Development"

// rootCmd is the top-level "ua3f" command. Its flags are registered in init
// and it delegates execution to runRoot.
var rootCmd = &cobra.Command{
	Use:          "ua3f",
	Short:        "Advanced HTTP Rewriting Tool",
	SilenceUsage: true,
	RunE:         runRoot,
}

// Execute runs the root command and terminates the process with exit status 1
// when the command returns an error.
func Execute() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	os.Exit(1)
}

// init registers the root command's flags and connects every configuration
// key to its command-line flag and environment variable through viper.
//
// A binding failure can only be caused by a programming error (for example a
// flag name that matches no registered flag), so it panics at start-up rather
// than silently leaving the key unbound.
func init() {
	cobra.OnInitialize(initConfig)

	flags := rootCmd.Flags()

	// General flags (most of them have a one-letter shorthand)
	flags.StringP("config", "c", "", "Config file path")
	flags.StringP("mode", "m", "", "Server mode: HTTP, SOCKS5, TPROXY, REDIRECT, NFQUEUE")
	flags.StringP("bind", "b", "", "Bind address")
	flags.IntP("port", "p", 0, "Port")
	flags.StringP("log-level", "l", "", "Log level")
	flags.StringP("ua", "f", "", "User-Agent")
	flags.StringP("ua-regex", "r", "", "User-Agent regex")
	flags.BoolP("partial", "s", false, "Enable regex partial replace")
	flags.StringP("rewrite-mode", "x", "", "Rewrite mode: GLOBAL, DIRECT, RULE")
	flags.String("api-server", "", "api-server listen address (e.g. 127.0.0.1:9000), empty to disable")
	flags.String("api-server-secret", "", "api-server secret for authentication, empty to disable auth")
	flags.BoolP("version", "v", false, "Show version")
	flags.BoolP("generate-config", "g", false, "Generate template config file")

	flags.Bool("include-lan-routes", false, "Include LAN routes from proxying")

	// Rewrite rule flags (JSON encoded)
	flags.String("header-rewrite", "", "Header rewrite json rules")
	flags.String("body-rewrite", "", "Body rewrite json rules")
	flags.String("url-redirect", "", "URL redirect json rules")

	// Packet-level rewrite flags
	flags.Bool("ttl", false, "Set TTL")
	flags.Bool("ipid", false, "Set IP ID")
	flags.Bool("tcpts", false, "Delete TCP Timestamp")
	flags.Bool("tcpwin", false, "Set TCP Initial Window")
	flags.Bool("block-quic", false, "Block QUIC by dropping outbound UDP/443 traffic")

	flags.Bool("l3-rewrite-ttl", false, "Set TTL (legacy flag, same as --ttl)")
	flags.Bool("l3-rewrite-ipid", false, "Set IP ID (legacy flag, same as --ipid)")
	flags.Bool("l3-rewrite-tcpts", false, "Delete TCP Timestamp (legacy flag, same as --tcpts)")
	flags.Bool("l3-rewrite-tcpwin", false, "Set TCP Initial Window (legacy flag, same as --tcpwin)")
	flags.Bool("l3-rewrite-bpf-offload", false, "Enable BPF offloading for L3 rewrite (requires kernel support)")

	// Desync flags
	flags.Bool("desync-reorder", false, "Enable desync reorder")
	flags.Uint("desync-reorder-bytes", 0, "Desync reorder bytes")
	flags.Uint("desync-reorder-packets", 0, "Desync reorder packets")
	flags.Bool("desync-inject", false, "Enable desync inject")
	flags.Uint("desync-inject-ttl", 0, "Desync inject TTL")
	flags.String("desync-ports", "", "Desync ports")

	// MitM flags
	flags.Bool("mitm", false, "Enable HTTPS MitM")
	flags.String("mitm-hostname", "", "MitM hostname list (comma-separated, supports wildcard * and :port suffix)")
	flags.String("mitm-ca-p12", "", "Path to MitM CA PKCS#12 file")
	flags.String("mitm-ca-p12-base64", "", "Base64-encoded PKCS#12 data for MitM CA")
	flags.String("mitm-ca-passphrase", "", "Passphrase for MitM CA PKCS#12 file")
	flags.Bool("mitm-insecure-skip-verify", false, "Skip server certificate verification in MitM")

	// BPF
	flags.Bool("bpf-offload", false, "Enable BPF offloading (requires kernel support)")

	// Viper key -> flag name. The "version" and "generate-config" flags are
	// intentionally absent: runRoot reads them straight from the flag set.
	flagBindings := []struct{ key, flag string }{
		{"config", "config"},
		{"server-mode", "mode"},
		{"bind-address", "bind"},
		{"port", "port"},
		{"log-level", "log-level"},
		{"user-agent", "ua"},
		{"user-agent-regex", "ua-regex"},
		{"user-agent-partial-replace", "partial"},
		{"rewrite-mode", "rewrite-mode"},
		{"header-rewrite-json", "header-rewrite"},
		{"body-rewrite-json", "body-rewrite"},
		{"url-redirect-json", "url-redirect"},

		{"include-lan-routes", "include-lan-routes"},

		{"ttl", "ttl"},
		{"ipid", "ipid"},
		{"tcp_timestamp", "tcpts"},
		{"tcp_initial_window", "tcpwin"},

		{"l3-rewrite.bpf-offload", "l3-rewrite-bpf-offload"},
		{"l3-rewrite.ttl", "l3-rewrite-ttl"},
		{"l3-rewrite.ipid", "l3-rewrite-ipid"},
		{"l3-rewrite.tcpts", "l3-rewrite-tcpts"},
		{"l3-rewrite.tcpwin", "l3-rewrite-tcpwin"},
		{"l3-rewrite.block-quic", "block-quic"},

		{"desync.reorder", "desync-reorder"},
		{"desync.reorder-bytes", "desync-reorder-bytes"},
		{"desync.reorder-packets", "desync-reorder-packets"},
		{"desync.inject", "desync-inject"},
		{"desync.inject-ttl", "desync-inject-ttl"},
		{"desync.desync-ports", "desync-ports"},

		{"api-server", "api-server"},
		{"api-server-secret", "api-server-secret"},

		{"mitm.enabled", "mitm"},
		{"mitm.hostname", "mitm-hostname"},
		{"mitm.ca-p12", "mitm-ca-p12"},
		{"mitm.ca-p12-base64", "mitm-ca-p12-base64"},
		{"mitm.ca-passphrase", "mitm-ca-passphrase"},
		{"mitm.insecure-skip-verify", "mitm-insecure-skip-verify"},

		{"bpf-offload", "bpf-offload"},
	}
	for _, b := range flagBindings {
		// BindPFlag fails when Lookup returns nil, i.e. the flag name is wrong.
		if err := viper.BindPFlag(b.key, flags.Lookup(b.flag)); err != nil {
			panic(fmt.Sprintf("cmd: binding flag %q to config key %q: %v", b.flag, b.key, err))
		}
	}

	// Bind environment variables
	viper.SetEnvPrefix("UA3F")
	viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_", ".", "_"))
	viper.AutomaticEnv()

	// Viper key -> explicit environment variable name, kept for backward
	// compatibility with the variable names users already rely on.
	envBindings := []struct{ key, env string }{
		{"server-mode", "UA3F_SERVER_MODE"},
		{"bind-address", "UA3F_BIND_ADDRESS"},
		{"port", "UA3F_PORT"},
		{"log-level", "UA3F_LOG_LEVEL"},
		{"rewrite-mode", "UA3F_REWRITE_MODE"},
		{"user-agent", "UA3F_PAYLOAD_UA"},
		{"user-agent-regex", "UA3F_UA_REGEX"},
		{"user-agent-partial-replace", "UA3F_PARTIAL_REPLACE"},

		{"include-lan-routes", "UA3F_INCLUDE_LAN_ROUTES"},

		{"ttl", "UA3F_TTL"},
		{"ipid", "UA3F_IPID"},
		{"tcp_timestamp", "UA3F_TCPTS"},
		{"tcp_initial_window", "UA3F_TCP_INIT_WINDOW"},

		{"l3-rewrite.bpf-offload", "UA3F_L3_REWRITE_BPF_OFFLOAD"},
		{"l3-rewrite.ttl", "UA3F_L3_REWRITE_TTL"},
		{"l3-rewrite.ipid", "UA3F_L3_REWRITE_IPID"},
		{"l3-rewrite.tcpts", "UA3F_L3_REWRITE_TCPTS"},
		{"l3-rewrite.tcpwin", "UA3F_L3_REWRITE_TCPWIN"},
		{"l3-rewrite.block-quic", "UA3F_L3_REWRITE_BLOCK_QUIC"},

		{"desync.reorder", "UA3F_DESYNC_REORDER"},
		{"desync.reorder-bytes", "UA3F_DESYNC_REORDER_BYTES"},
		{"desync.reorder-packets", "UA3F_DESYNC_REORDER_PACKETS"},
		{"desync.inject", "UA3F_DESYNC_INJECT"},
		{"desync.inject-ttl", "UA3F_DESYNC_INJECT_TTL"},
		{"desync.desync-ports", "UA3F_DESYNC_PORTS"},

		{"mitm.enabled", "UA3F_MITM_ENABLED"},
		{"mitm.hostname", "UA3F_MITM_HOSTNAME"},
		{"mitm.ca-p12", "UA3F_MITM_CA_P12"},
		{"mitm.ca-p12-base64", "UA3F_MITM_CA_P12_BASE64"},
		{"mitm.ca-passphrase", "UA3F_MITM_CA_PASSPHRASE"},
		{"mitm.insecure-skip-verify", "UA3F_MITM_INSECURE_SKIP_VERIFY"},

		{"bpf-offload", "UA3F_BPF_OFFLOAD"},

		{"api-server", "UA3F_API_SERVER"},
		{"api-server-secret", "UA3F_API_SERVER_SECRET"},

		{"header-rewrite-json", "UA3F_HEADER_REWRITE"},
		{"body-rewrite-json", "UA3F_BODY_REWRITE"},
		{"url-redirect-json", "UA3F_URL_REDIRECT"},
	}
	for _, b := range envBindings {
		if err := viper.BindEnv(b.key, b.env); err != nil {
			panic(fmt.Sprintf("cmd: binding env %q to config key %q: %v", b.env, b.key, err))
		}
	}
}

// initConfig is the cobra initializer for the root command. When the "config"
// key names a file, that file is merged into viper (the process exits with
// status 1 if it cannot be read); afterwards the built-in defaults are set.
func initConfig() {
	if path := viper.GetString("config"); path != "" {
		viper.SetConfigFile(path)
		err := viper.MergeInConfig()
		if err != nil {
			slog.Error("Failed to read config file", slog.Any("error", err))
			os.Exit(1)
		}
	}

	// Fallback values used when neither a flag, an environment variable nor
	// the config file provides the key.
	defaults := []struct {
		key   string
		value any
	}{
		{"server-mode", "SOCKS5"},
		{"bind-address", "127.0.0.1"},
		{"port", 1080},
		{"log-level", "info"},
		{"user-agent", "FFF"},
		{"rewrite-mode", "GLOBAL"},

		{"desync.reorder-bytes", 8},
		{"desync.reorder-packets", 1500},
		{"desync.inject-ttl", 3},
	}
	for _, d := range defaults {
		viper.SetDefault(d.key, d.value)
	}
}

// runRoot is the RunE handler of the root command.
//
// It first serves the two one-shot flags (-v/--version prints the version,
// -g/--generate-config writes a template config) and returns. Otherwise it
// builds the configuration from viper, sets up logging and the daemon, then
// starts the API server, the netlink helper, the optional desync server and
// the main server, in that order. Once everything is running it blocks on
// signals: SIGHUP restarts the system, any other registered signal closes the
// system and returns nil.
//
// Any start-up error is returned to cobra; errors that occur after the API
// server has been created also trigger apiSrv.CloseSystem before returning.
func runRoot(cmd *cobra.Command, args []string) error {
	// Handle -v / --version
	showVer, err := cmd.Flags().GetBool("version")
	if err != nil {
		return fmt.Errorf("reading version flag: %w", err)
	}
	if showVer {
		fmt.Printf("UA3F version %s\n", AppVersion)
		return nil
	}

	// Handle -g / --generate-config
	genConfig, err := cmd.Flags().GetBool("generate-config")
	if err != nil {
		return fmt.Errorf("reading generate-config flag: %w", err)
	}
	if genConfig {
		if _, err := config.GenerateTemplateConfig(true); err != nil {
			return fmt.Errorf("failed to generate template config: %w", err)
		}
		fmt.Println("Template config file 'config.yaml' generated successfully.")
		return nil
	}

	// Build config from flags/env/config file
	cfg, err := config.BuildConfigFromViper()
	if err != nil {
		return fmt.Errorf("config error: %w", err)
	}

	// Set up logging and log the configuration
	logBroadcaster, err := log.SetLogConf(cfg.LogLevel)
	if err != nil {
		return fmt.Errorf("log setup error: %w", err)
	}
	log.LogHeader(AppVersion, cfg)

	if err := daemon.DaemonSetup(cfg); err != nil {
		slog.Error("daemon.DaemonSetup", slog.Any("error", err))
		return err
	}

	apiSrv := api.New(AppVersion, cfg, logBroadcaster)

	// abort logs a start-up failure under the given step name, closes whatever
	// has been started so far and hands the error back for returning.
	abort := func(step string, err error) error {
		slog.Error(step, slog.Any("error", err))
		apiSrv.CloseSystem()
		return err
	}

	// Start api server
	if err := apiSrv.Start(); err != nil {
		return abort("apiSrv.Start", err)
	}

	// Start packet modification helper
	apiSrv.Helper = netlink.New(cfg)
	if err := apiSrv.Helper.Start(); err != nil {
		return abort("helper.Start", err)
	}

	// Start desync server if enabled
	if cfg.Desync.Reorder || cfg.Desync.Inject {
		apiSrv.Desync = desync.New(cfg)
		if err := apiSrv.Desync.Start(); err != nil {
			return abort("desync.Start", err)
		}
	}

	// Start main server
	srv, err := server.NewServer(cfg)
	if err != nil {
		return abort("server.NewServer", err)
	}
	apiSrv.Server = srv
	if err := srv.Start(); err != nil {
		return abort("srv.Start", err)
	}

	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigCh)

	for {
		sig := <-sigCh
		slog.Info("Received signal", slog.String("signal", sig.String()))

		if sig == syscall.SIGHUP {
			// A failed restart is logged but does not stop the process.
			if err := apiSrv.RestartSystem(); err != nil {
				slog.Error("Failed to restart ua3f", slog.Any("error", err))
			}
			continue
		}

		// SIGQUIT, SIGINT, SIGTERM (and anything else that may get registered
		// later): always clean up before leaving.
		apiSrv.CloseSystem()
		return nil
	}
}
