"""
Graph orchestration for tb-py LangGraph
This file handles how nodes interact with each other and wraps node classes as LangGraph functions.
"""

from typing import Dict, Any, Callable
from langchain_core.messages import AIMessage
from handit_ai import tracing
import asyncio

from ...config import Config
from ...state import AgentState, set_stage_result, add_message, set_error, clear_error

# Import node classes from /src/nodes/
from ...nodes.llm.l1.processor import L1LLMNode
from ...nodes.llm.l_2.processor import L2LLMNode
from ...nodes.tools.t1.processor import T1ToolNode
from ...nodes.tools.t_2.processor import T2ToolNode


# Module-level cache of node objects, filled lazily by _get_node_instance().
# Keys have the form "<node_name>_<node_type>"; a key is only constructed the
# first time it is requested and the same object is returned afterwards.
_node_instances: Dict[str, Any] = {}

def _get_node_instance(node_name: str, node_type: str = "llm"):
    """
    Get or create the cached instance for a node (one per name/type pair).

    The node class is located by naming convention: dashes in ``node_name``
    become underscores, and the class name is the underscore-separated parts
    capitalized and joined, followed by ``LLMNode`` or ``ToolNode``. For
    example ``"l_2"`` with type ``"llm"`` resolves to ``L2LLMNode`` in
    ``src.nodes.llm.l_2.processor``.

    Args:
        node_name: Name of the node (dashes and underscores both accepted)
        node_type: Type of node; ``"llm"`` selects an LLM node, any other
            value is treated as a tool node

    Returns:
        Node instance (the same object on every call with the same arguments)

    Raises:
        ImportError: If the node's ``processor`` module cannot be imported
        AttributeError: If that module does not define the expected class
    """
    # Function-scope import so this helper does not rely on a file-level import.
    import importlib

    key = f"{node_name}_{node_type}"

    # No ``global`` statement needed: the cache dict is mutated, never rebound.
    if key in _node_instances:
        return _node_instances[key]

    config = Config()

    # LLM and tool nodes differ only in sub-package and class-name suffix.
    if node_type == "llm":
        package, suffix = "llm", "LLMNode"
    else:
        package, suffix = "tools", "ToolNode"

    module_name = node_name.replace("-", "_")
    class_name = "".join(part.capitalize() for part in module_name.split("_")) + suffix

    # NOTE(review): the module path is absolute from ``src`` while this file's
    # static imports are relative -- confirm the package is importable as ``src``.
    module = importlib.import_module(f"src.nodes.{package}.{module_name}.processor")
    node_class = getattr(module, class_name)

    _node_instances[key] = node_class(config)
    return _node_instances[key]

# LangGraph node functions that wrap node classes

async def l1_node(state: AgentState) -> AgentState:
    """
    LangGraph wrapper for L1 LLM node.
    This function handles graph orchestration and calls the actual node class.

    L1 is the first entry in the stage order ("l1", "l_2", "t1", "t_2"), so
    there is never an earlier stage result to consume: the node always runs
    on the stringified ``state["input"]``.

    Args:
        state: Current agent state

    Returns:
        State returned by ``set_stage_result``/``add_message`` with the "l1"
        result and an "L1: ..." AI message, or the state returned by
        ``set_error`` if anything raised
    """
    try:
        # Get node instance (singleton)
        node_instance = _get_node_instance("l1", "llm")

        # First stage: the raw agent input is the only possible source.
        processing_input = str(state.get("input", ""))

        # Call the actual node class
        result = await node_instance.run(processing_input)

        # Record the stage result, then announce it as an AI message
        state = set_stage_result(state, "l1", result)
        state = add_message(state, AIMessage(content=f"L1: {result}"))

        return state

    except Exception as e:
        # Graph boundary: report the failure through the state instead of
        # letting the exception escape from the node.
        error_msg = f"L1 node error: {str(e)}"
        print(f"❌ {error_msg}")
        return set_error(state, error_msg)

async def l_2_node(state: AgentState) -> AgentState:
    """
    Run the L2 LLM node as a LangGraph step.

    The node is fed the result of the stage that precedes "l_2" in the fixed
    stage order (that is "l1") when that result is present and truthy;
    otherwise it is fed the stringified ``state["input"]``.

    Args:
        state: Current agent state

    Returns:
        State carrying the "l_2" result and an "L2: ..." AI message, or the
        state returned by ``set_error`` if anything raised
    """
    stage_order = ("l1", "l_2", "t1", "t_2")

    try:
        runner = _get_node_instance("l_2", "llm")

        raw_input = state.get("input", "")

        # Look one step back in the stage order for something to build on.
        upstream_output = ""
        if state.get("results"):
            position = stage_order.index("l_2")
            if position > 0:
                upstream_output = state["results"].get(stage_order[position - 1], "")

        # An empty/missing upstream result falls back to the original input.
        payload = upstream_output or str(raw_input)
        outcome = await runner.run(payload)

        state = set_stage_result(state, "l_2", outcome)
        state = add_message(state, AIMessage(content=f"L2: {outcome}"))
        return state

    except Exception as exc:
        failure = f"L2 node error: {exc!s}"
        print(f"❌ {failure}")
        return set_error(state, failure)

async def t1_tool_node(state: AgentState) -> AgentState:
    """
    Run the T1 tool node as a LangGraph step.

    The node is fed the result of the stage that precedes "t1" in the fixed
    stage order (that is "l_2") when that result is present and truthy;
    otherwise it is fed the stringified ``state["input"]``.

    Args:
        state: Current agent state

    Returns:
        State carrying the "t1" result and a "T1 Tool: ..." AI message, or
        the state returned by ``set_error`` if anything raised
    """
    stage_order = ("l1", "l_2", "t1", "t_2")

    try:
        runner = _get_node_instance("t1", "tool")

        raw_input = state.get("input", "")

        # Look one step back in the stage order for something to build on.
        upstream_output = ""
        if state.get("results"):
            position = stage_order.index("t1")
            if position > 0:
                upstream_output = state["results"].get(stage_order[position - 1], "")

        # An empty/missing upstream result falls back to the original input.
        payload = upstream_output or str(raw_input)
        outcome = await runner.run(payload)

        state = set_stage_result(state, "t1", outcome)
        state = add_message(state, AIMessage(content=f"T1 Tool: {outcome}"))
        return state

    except Exception as exc:
        failure = f"T1 tool node error: {exc!s}"
        print(f"❌ {failure}")
        return set_error(state, failure)

async def t_2_tool_node(state: AgentState) -> AgentState:
    """
    Run the T2 tool node as a LangGraph step.

    The node is fed the result of the stage that precedes "t_2" in the fixed
    stage order (that is "t1") when that result is present and truthy;
    otherwise it is fed the stringified ``state["input"]``.

    Args:
        state: Current agent state

    Returns:
        State carrying the "t_2" result and a "T2 Tool: ..." AI message, or
        the state returned by ``set_error`` if anything raised
    """
    stage_order = ("l1", "l_2", "t1", "t_2")

    try:
        runner = _get_node_instance("t_2", "tool")

        raw_input = state.get("input", "")

        # Look one step back in the stage order for something to build on.
        upstream_output = ""
        if state.get("results"):
            position = stage_order.index("t_2")
            if position > 0:
                upstream_output = state["results"].get(stage_order[position - 1], "")

        # An empty/missing upstream result falls back to the original input.
        payload = upstream_output or str(raw_input)
        outcome = await runner.run(payload)

        state = set_stage_result(state, "t_2", outcome)
        state = add_message(state, AIMessage(content=f"T2 Tool: {outcome}"))
        return state

    except Exception as exc:
        failure = f"T2 tool node error: {exc!s}"
        print(f"❌ {failure}")
        return set_error(state, failure)


async def finalizer_node(state: AgentState) -> AgentState:
    """
    Collect every stage result into a single "finalizer" entry.

    Builds a summary (stage count, stage names, ISO timestamp from
    ``datetime.now()``) of whatever is currently under ``state["results"]``,
    stores it together with those results under ``results["finalizer"]``,
    and appends one AI message announcing the merge. The incoming results
    dict and messages list are copied, not modified.

    Args:
        state: Current agent state; the "input", "messages", "context",
            "results", "error" and "metadata" keys are read

    Returns:
        A new state dict with those keys plus "current_stage" set to
        "finalizer". If anything raises, results and messages are passed
        through as-is and "error" is set to "Finalizer error: ...".
    """
    try:
        from datetime import datetime
        from langchain_core.messages import AIMessage

        stage_results = state.get("results", {})
        stage_count = len(stage_results)

        summary = {
            "total_stages": stage_count,
            "completed_stages": list(stage_results.keys()),
            "finalization_timestamp": datetime.now().isoformat()
        }

        # Work on a shallow copy so the "finalizer" entry is not inserted
        # into the very dict it references as "merged_results".
        results_out = dict(state.get("results", {}))
        results_out["finalizer"] = {
            "merged_results": stage_results,
            "execution_summary": summary
        }

        messages_out = list(state.get("messages", []))
        messages_out.append(
            AIMessage(content=f"Finalizer: Merged results from {stage_count} parallel stages")
        )

        return {
            "input": state.get("input"),
            "messages": messages_out,
            "context": state.get("context", {}),
            "results": results_out,
            "current_stage": "finalizer",
            "error": state.get("error"),
            "metadata": state.get("metadata", {})
        }

    except Exception as exc:
        # Report the failure in the state; everything else is passed through.
        return {
            "input": state.get("input"),
            "messages": state.get("messages", []),
            "context": state.get("context", {}),
            "results": state.get("results", {}),
            "current_stage": "finalizer",
            "error": f"Finalizer error: {exc!s}",
            "metadata": state.get("metadata", {})
        }

def get_graph_nodes(config: Config) -> Dict[str, Callable]:
    """
    Return the node functions that make up the graph, keyed by node name.

    The mapping is fixed: the two LLM nodes, the two tool nodes, and the
    finalizer node, in that order. ``config`` is accepted but not consulted
    by the current implementation.

    Args:
        config: Configuration object (currently unused)

    Returns:
        Dictionary of node functions including finalizer
    """
    return {
        # LLM nodes
        "l1": l1_node,
        "l_2": l_2_node,
        # Tool nodes
        "t1": t1_tool_node,
        "t_2": t_2_tool_node,
        # Finalizer is always present to merge the stage results
        "finalizer": finalizer_node,
    }

def create_custom_node(stage_name: str, logic_func: Callable) -> Callable:
    """
    Create a custom node function.

    The returned coroutine clears any error on the state, awaits
    ``logic_func(state)``, and stores its return value as the result of
    ``stage_name``. Any exception is reported through ``set_error`` as
    "<stage_name> error: <message>" instead of being raised.

    Args:
        stage_name: Name of the stage
        logic_func: Logic function for the node; it is awaited, so it must
            return an awaitable (e.g. be an ``async def``)

    Returns:
        Node function
    """
    async def custom_node(state: AgentState) -> AgentState:
        try:
            # The other state helpers used in this file (set_stage_result,
            # set_error) are consumed through their return value, so do the
            # same here rather than discarding it. The None check keeps this
            # working if clear_error mutates in place and returns nothing.
            cleared = clear_error(state)
            if cleared is not None:
                state = cleared

            result = await logic_func(state)
            return set_stage_result(state, stage_name, result)
        except Exception as e:
            return set_error(state, f"{stage_name} error: {str(e)}")

    return custom_node
