package server

import (
	"context"
	"fmt"
	"log/slog"
	"net/http"
	"time"

	"github.com/vaayne/mcphub/internal/client"
	"github.com/vaayne/mcphub/internal/config"
	"github.com/vaayne/mcphub/internal/tools"

	"github.com/modelcontextprotocol/go-sdk/mcp"
)

// TransportConfig holds transport configuration for the server
type TransportConfig struct {
	Type string // "stdio", "http", or "sse"
	Host string // Host to bind for HTTP/SSE
	Port int    // Port to bind for HTTP/SSE
}

// Server represents the MCP hub server
type Server struct {
	config          *config.Config
	logger          *slog.Logger
	mcpServer       *mcp.Server
	clientManager   *client.Manager
	builtinRegistry *tools.BuiltinToolRegistry
	toolCallTimeout time.Duration
	httpServer      *http.Server // for graceful shutdown of HTTP/SSE
}

// NewServer creates a new MCP hub server
func NewServer(cfg *config.Config, logger *slog.Logger) *Server {
	return &Server{
		config:          cfg,
		logger:          logger,
		toolCallTimeout: 60 * time.Second,
	}
}

// Start starts the MCP server with the specified transport
func (s *Server) Start(ctx context.Context, transportCfg TransportConfig) error {
	s.logger.Info("Starting MCP hub server",
		slog.String("transport", transportCfg.Type),
	)

	// Initialize client manager
	s.clientManager = client.NewManager(s.logger)

	// Initialize builtin tool registry
	s.builtinRegistry = tools.NewBuiltinToolRegistry(s.logger)

	// Register built-in tools
	s.registerBuiltinTools()

	// Connect to remote servers
	if err := s.connectToRemoteServers(); err != nil {
		return fmt.Errorf("failed to connect to remote servers: %w", err)
	}

	// Create MCP server
	s.mcpServer = mcp.NewServer(&mcp.Implementation{
		Name:    "hub",
		Version: "v1.0.0",
	}, nil)

	// Register all tools with the MCP server
	if err := s.registerAllTools(); err != nil {
		return fmt.Errorf("failed to register tools: %w", err)
	}

	// Start with the appropriate transport
	switch transportCfg.Type {
	case "stdio":
		return s.startStdio(ctx)
	case "http":
		return s.startHTTP(ctx, transportCfg)
	case "sse":
		return s.startSSE(ctx, transportCfg)
	default:
		return fmt.Errorf("unsupported transport type: %s", transportCfg.Type)
	}
}

// startStdio starts the server with stdio transport
func (s *Server) startStdio(ctx context.Context) error {
	s.logger.Info("Starting stdio transport")
	transport := &mcp.StdioTransport{}
	if err := s.mcpServer.Run(ctx, transport); err != nil {
		return fmt.Errorf("server failed: %w", err)
	}
	return nil
}

// startHTTP starts the server with StreamableHTTP transport
func (s *Server) startHTTP(ctx context.Context, cfg TransportConfig) error {
	addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
	s.logger.Info("Starting HTTP transport", slog.String("address", addr))

	handler := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server {
		return s.mcpServer
	}, nil)

	mux := http.NewServeMux()
	mux.Handle("/mcp", handler)

	s.httpServer = &http.Server{
		Addr:    addr,
		Handler: mux,
	}

	s.logger.Info("MCP Hub server running", slog.String("url", fmt.Sprintf("http://%s/mcp", addr)))

	if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		return fmt.Errorf("HTTP server failed: %w", err)
	}
	return nil
}

// startSSE starts the server with SSE transport
func (s *Server) startSSE(ctx context.Context, cfg TransportConfig) error {
	addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
	s.logger.Info("Starting SSE transport", slog.String("address", addr))

	handler := mcp.NewSSEHandler(func(r *http.Request) *mcp.Server {
		return s.mcpServer
	}, nil)

	mux := http.NewServeMux()
	mux.Handle("/sse", handler)

	s.httpServer = &http.Server{
		Addr:    addr,
		Handler: mux,
	}

	s.logger.Info("MCP Hub server running", slog.String("url", fmt.Sprintf("http://%s/sse", addr)))

	if err := s.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		return fmt.Errorf("SSE server failed: %w", err)
	}
	return nil
}

// Stop stops the MCP server
func (s *Server) Stop() error {
	s.logger.Info("Stopping MCP hub server")

	// Shutdown HTTP server if running
	if s.httpServer != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := s.httpServer.Shutdown(ctx); err != nil {
			s.logger.Error("Error shutting down HTTP server", slog.String("error", err.Error()))
		}
	}

	// Disconnect from all remote servers
	if s.clientManager != nil {
		if err := s.clientManager.DisconnectAll(); err != nil {
			s.logger.Error("Error disconnecting from remote servers", slog.String("error", err.Error()))
			return err
		}
	}

	return nil
}

// registerBuiltinTools registers all built-in tools
func (s *Server) registerBuiltinTools() {
	// Schema for list tool
	// Register list tool (no parameters - returns all tools)
	s.builtinRegistry.RegisterTool(config.BuiltinTool{
		Name:        "list",
		Description: tools.ListDescription,
		InputSchema: map[string]any{
			"type":       "object",
			"properties": map[string]any{},
		},
	})

	// Register inspect tool
	s.builtinRegistry.RegisterTool(config.BuiltinTool{
		Name:        "inspect",
		Description: tools.InspectDescription,
		InputSchema: map[string]any{
			"type": "object",
			"properties": map[string]any{
				"name": map[string]any{
					"type":        "string",
					"description": "Namespaced tool name (serverID__toolName)",
					"maxLength":   500,
				},
			},
			"required": []string{"name"},
		},
	})

	// Register invoke tool
	s.builtinRegistry.RegisterTool(config.BuiltinTool{
		Name:        "invoke",
		Description: tools.InvokeDescription,
		InputSchema: map[string]any{
			"type": "object",
			"properties": map[string]any{
				"name": map[string]any{
					"type":        "string",
					"description": "Namespaced tool name (serverID__toolName)",
					"maxLength":   500,
				},
				"params": map[string]any{
					"type":        "object",
					"description": "Optional parameters to pass to the tool",
				},
			},
			"required": []string{"name"},
		},
	})

	// Register exec tool
	s.builtinRegistry.RegisterTool(config.BuiltinTool{
		Name:        "exec",
		Description: tools.ExecDescription,
		InputSchema: map[string]any{
			"type": "object",
			"properties": map[string]any{
				"code": map[string]any{
					"type":        "string",
					"minLength":   1,
					"description": "JavaScript to execute (async/await, timers, require for node:* built-ins). Use mcp.callTool() for MCP tools.",
				},
			},
			"required": []string{"code"},
		},
	})
}

// connectToRemoteServers connects to all configured remote MCP servers
func (s *Server) connectToRemoteServers() error {
	var errors []error

	for serverID, serverCfg := range s.config.MCPServers {
		// Skip disabled servers
		if !serverCfg.IsEnabled() {
			s.logger.Info("Skipping disabled server", slog.String("serverID", serverID))
			continue
		}

		s.logger.Info("Connecting to server", slog.String("serverID", serverID))
		if err := s.clientManager.ConnectToServer(serverID, serverCfg); err != nil {
			s.logger.Error("Failed to connect to server",
				slog.String("serverID", serverID),
				slog.String("error", err.Error()),
			)

			// If server is required, return error immediately
			if serverCfg.Required {
				return fmt.Errorf("required server %s failed to connect: %w", serverID, err)
			}

			errors = append(errors, fmt.Errorf("server %s: %w", serverID, err))
		}
	}

	if len(errors) > 0 {
		s.logger.Warn("Some optional servers failed to connect", slog.Int("count", len(errors)))
	}

	return nil
}

// registerAllTools registers all tools (built-in only) with the MCP server
func (s *Server) registerAllTools() error {
	// Register built-in tools
	for toolName, builtinTool := range s.builtinRegistry.GetAllTools() {
		if err := s.registerBuiltinToolHandler(toolName, builtinTool); err != nil {
			return fmt.Errorf("failed to register built-in tool %s: %w", toolName, err)
		}
	}

	s.logger.Info("Registered built-in tools",
		slog.Int("count", len(s.builtinRegistry.GetAllTools())),
	)

	return nil
}

// registerBuiltinToolHandler registers a handler for a built-in tool
func (s *Server) registerBuiltinToolHandler(toolName string, builtinTool config.BuiltinTool) error {
	// Create MCP tool schema
	mcpTool := &mcp.Tool{
		Name:        toolName,
		Description: builtinTool.Description,
		InputSchema: builtinTool.InputSchema,
	}

	// Register the tool with a handler that calls the appropriate built-in function
	handler := func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		return s.handleBuiltinTool(ctx, toolName, req)
	}

	// Use Server.AddTool to register the tool
	s.mcpServer.AddTool(mcpTool, handler)

	s.logger.Debug("Registered built-in tool", slog.String("name", toolName))
	return nil
}

// handleBuiltinTool handles calls to built-in tools
func (s *Server) handleBuiltinTool(ctx context.Context, toolName string, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	s.logger.Debug("Handling built-in tool call", slog.String("tool", toolName))

	// Apply timeout to prevent DoS attacks
	callCtx, cancel := context.WithTimeout(ctx, s.toolCallTimeout)
	defer cancel()

	// Create ToolProvider adapter for the client manager
	provider := tools.NewManagerAdapter(s.clientManager)

	switch toolName {
	case "list":
		return tools.HandleListTool(callCtx, provider, req)
	case "inspect":
		return tools.HandleInspectTool(callCtx, provider, req)
	case "invoke":
		return tools.HandleInvokeTool(callCtx, provider, req)
	case "exec":
		return tools.HandleExecuteTool(callCtx, s.logger, s.clientManager, req)
	default:
		return nil, fmt.Errorf("unknown built-in tool: %s", toolName)
	}
}
