"""
Chat completion module for MiniMax model interactions.

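This module sends streaming and non-streaming chat completion requests to the
MiniMax-M1-80k model. It maps Hugging Face and network failures onto the
package's own exception types, and it passes extra completion parameters
through to the client.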
"""

import logging
import sys
import requests
from typing import Optional, Dict, Any, Iterator, Generator

from huggingface_hub import InferenceClient

# Import our custom exceptions and retry mechanisms
from .exceptions import (
    MiniMaxClientError, 
    NetworkError, 
    ModelError, 
    APIError, 
    RateLimitError,
    ServiceUnavailableError,
    TimeoutError as MiniMaxTimeoutError
)
from .retry import with_retry, NETWORK_RETRY_CONFIG

class _UnavailableHfError(Exception):
    """Stand-in for a Hugging Face exception class that could not be imported.

    Nothing ever raises this class, so an ``except`` clause that names it
    simply never matches. Aliasing a missing name to ``Exception`` instead
    would make the first such clause swallow every error.
    """


# Try to import HuggingFace exceptions. The timeout error and the HTTP errors
# are imported separately so that one missing name does not disable the rest.
try:
    from huggingface_hub import InferenceTimeoutError
except ImportError:
    InferenceTimeoutError = _UnavailableHfError

try:
    from huggingface_hub import (
        HfHubHTTPError,
        RepositoryNotFoundError,
        GatedRepoError,
        BadRequestError,
    )
except ImportError:
    try:
        # Some huggingface_hub releases expose these only under `.utils`.
        from huggingface_hub.utils import (
            HfHubHTTPError,
            RepositoryNotFoundError,
            GatedRepoError,
            BadRequestError,
        )
    except ImportError:
        # Fallback for older versions: handlers naming these never match.
        HfHubHTTPError = _UnavailableHfError
        RepositoryNotFoundError = _UnavailableHfError
        GatedRepoError = _UnavailableHfError
        BadRequestError = _UnavailableHfError

# Set up logger: a module-level logger named after this module (`__name__`),
# shared by every function below for info/debug/warning/error messages.
logger = logging.getLogger(__name__)


class ChatCompletionError(MiniMaxClientError):
    """Raised when a chat completion fails for a reason not covered by a more
    specific client error, or when a model response cannot be processed."""


def _parse_retry_after(response: Any) -> Optional[int]:
    """Return the ``Retry-After`` header of *response* as whole seconds, or None.

    Only the integer-seconds form is understood. A missing or empty
    ``headers`` attribute, a missing header, or a non-integer value (for
    example an HTTP-date) all yield None.
    """
    headers = getattr(response, 'headers', None)
    if not headers:
        return None
    raw_value = headers.get('Retry-After')
    if not raw_value:
        return None
    try:
        return int(raw_value)
    except (TypeError, ValueError):
        return None


def _http_error_to_client_error(error: Exception, model_name: str) -> Exception:
    """Map a Hugging Face HTTP error onto the matching MiniMax client exception.

    The original error is logged here. The mapped exception is returned rather
    than raised so the caller can chain it with ``raise ... from error``.
    """
    response = getattr(error, 'response', None)
    if response is None:
        # No HTTP response attached: treat it as a transport-level failure.
        logger.error(f"HTTP error: {error}")
        return NetworkError(
            "Network error occurred during API request",
            error_code="NETWORK_ERROR"
        )

    status_code = response.status_code
    if status_code == 429:
        logger.error(f"Rate limit exceeded: {error}")
        return RateLimitError(
            "Rate limit exceeded - please wait before making another request",
            retry_after=_parse_retry_after(response),
            error_code="RATE_LIMIT_EXCEEDED",
            details={"status_code": status_code}
        )
    if status_code == 404:
        logger.error(f"Model not found: {error}")
        return ModelError(
            f"Model '{model_name}' not found or not accessible",
            error_code="MODEL_NOT_FOUND",
            details={"model_name": model_name}
        )
    if status_code == 401:
        logger.error(f"Authentication failed: {error}")
        return APIError(
            "Authentication failed - please check your HF_TOKEN",
            error_code="AUTHENTICATION_FAILED",
            details={"status_code": status_code}
        )
    if status_code >= 500:
        logger.error(f"Server error {status_code}: {error}")
        return ServiceUnavailableError(
            f"Server error {status_code} - service temporarily unavailable",
            error_code="SERVER_ERROR",
            details={"status_code": status_code}
        )
    logger.error(f"HTTP error {status_code}: {error}")
    return APIError(
        f"HTTP error {status_code}",
        error_code="HTTP_ERROR",
        details={"status_code": status_code}
    )


@with_retry(NETWORK_RETRY_CONFIG)
def create_chat_completion(
    client: InferenceClient,
    model_name: str,
    user_message: str,
    stream: bool = True,
    **kwargs: Any
) -> Optional[str]:
    """
    Create a streaming or non-streaming chat completion request to the MiniMax model.

    The function is wrapped by ``with_retry(NETWORK_RETRY_CONFIG)``. That
    decorator decides which of the exceptions below are retried.

    Args:
        client: The configured inference client
        model_name: Name of the model to use for completion
        user_message: The user's message content (sent as a single user turn)
        stream: Whether to use streaming response (default: True)
        **kwargs: Additional parameters forwarded to ``client.chat_completion``

    Returns:
        The response text with surrounding whitespace stripped. An empty
        string means no content was received. Failures are raised, not
        signalled by None.

    Raises:
        MiniMaxTimeoutError: The model inference timed out.
        RateLimitError: HTTP 429.
        ModelError: The model does not exist, is not accessible, or is gated.
        APIError: Authentication failure, bad request, or other HTTP errors.
        ServiceUnavailableError: HTTP 5xx.
        NetworkError: Connection failures/timeouts, or HTTP errors without a response.
        ChatCompletionError: Response-processing failures and unexpected errors.
    """
    # Prepare the chat message
    messages = [{"role": "user", "content": user_message}]

    logger.info(f"Sending request to model: {model_name}")
    logger.info(f"User message: {user_message}")

    if stream:
        logger.info("Starting streaming response...")
        print("-" * 50)

    try:
        response = client.chat_completion(
            model=model_name,
            messages=messages,
            stream=bool(stream),
            **kwargs
        )
        if stream:
            return _process_streaming_response(response)
        return _process_non_streaming_response(response)

    except MiniMaxClientError:
        # Already one of our own exceptions (e.g. ChatCompletionError from the
        # response processors): propagate unchanged instead of re-wrapping it
        # as UNEXPECTED_ERROR in the catch-all below.
        raise

    except InferenceTimeoutError as e:
        logger.error(f"Model inference timeout: {e}")
        raise MiniMaxTimeoutError(
            "Model inference timeout - the model may be unavailable or overloaded",
            error_code="INFERENCE_TIMEOUT",
            details={"model_name": model_name, "user_message_length": len(user_message)}
        ) from e

    # In huggingface_hub, GatedRepoError subclasses RepositoryNotFoundError,
    # and both it and BadRequestError subclass HfHubHTTPError. The specific
    # handlers must therefore come first, or they are never reached.
    except GatedRepoError as e:
        logger.error(f"Access denied to gated repository: {e}")
        raise ModelError(
            f"The model '{model_name}' requires special access - please request access through Hugging Face Hub",
            error_code="GATED_REPOSITORY",
            details={"model_name": model_name}
        ) from e

    except RepositoryNotFoundError as e:
        logger.error(f"Repository not found: {e}")
        raise ModelError(
            f"The model '{model_name}' does not exist or is not accessible",
            error_code="REPOSITORY_NOT_FOUND",
            details={"model_name": model_name}
        ) from e

    except BadRequestError as e:
        logger.error(f"Bad request: {e}")
        raise APIError(
            "Bad request - invalid parameters or configuration",
            error_code="BAD_REQUEST",
            details={"message": str(e)}
        ) from e

    except HfHubHTTPError as e:
        raise _http_error_to_client_error(e, model_name) from e

    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
        logger.error(f"Network connection error: {e}")
        raise NetworkError(
            "Network connection failed",
            error_code="CONNECTION_ERROR",
            details={"error_type": type(e).__name__}
        ) from e

    except Exception as e:
        logger.error(f"Unexpected error during chat completion: {e}")
        raise ChatCompletionError(
            f"Unexpected error during chat completion: {str(e)}",
            error_code="UNEXPECTED_ERROR",
            details={"error_type": type(e).__name__, "model_name": model_name}
        ) from e


def _process_streaming_response(stream: Iterator[Any]) -> str:
    """
    Print and accumulate a streamed chat completion.

    Each chunk's text is printed as it arrives (no newline, flushed) and
    collected. A chunk whose content cannot be extracted, is not a string,
    or cannot be printed is logged and skipped. After the stream ends, a
    separator and the complete stripped response are printed.

    If iterating the stream fails part-way, the text already received is
    returned (best effort). If nothing was received, ChatCompletionError is
    raised.

    Args:
        stream: Iterator of response chunks

    Returns:
        The accumulated response text with surrounding whitespace stripped.

    Raises:
        ChatCompletionError: If processing fails before any content arrived.
    """
    # Text fragments in arrival order; joined once instead of repeated `+=`.
    parts = []
    chunk_count = 0

    try:
        for chunk in stream:
            chunk_count += 1

            # A single malformed chunk must not abort the whole stream.
            try:
                content = _extract_content_from_chunk(chunk)
                if not content:
                    continue
                if not isinstance(content, str):
                    # Skip before printing so the console matches the returned text.
                    logger.warning(
                        f"Error processing chunk {chunk_count}: "
                        f"unexpected content type {type(content).__name__}"
                    )
                    continue
                # Print content without newlines for streaming effect
                print(content, end='', flush=True)
                parts.append(content)
            except Exception as e:
                logger.warning(f"Error processing chunk {chunk_count}: {e}")

        print("\n" + "-" * 50)
        logger.info("Streaming completed successfully")
        logger.debug(f"Processed {chunk_count} chunks")

        response_content = "".join(parts).strip()
        if response_content:
            logger.debug(f"Complete response: {response_content}")
            print(f"\n📄 Complete response:\n{response_content}")
        else:
            logger.warning("No content received in streaming response")

        return response_content

    except Exception as e:
        logger.error(f"Error during streaming response processing: {e}")
        partial = "".join(parts)
        if partial:
            logger.info(f"Partial response received: {len(partial)} characters")
            return partial.strip()
        raise ChatCompletionError(f"Failed to process streaming response: {e}") from e


def _process_non_streaming_response(response: Any) -> str:
    """
    Extract, print and return the text of a non-streaming chat completion.

    The content is read from ``response.choices[0].message.content``. If any
    part of that path is missing, or the content is None, a warning is logged
    and an empty string is returned.

    Args:
        response: Response object from the model

    Returns:
        The response content with surrounding whitespace stripped, or "".

    Raises:
        ChatCompletionError: If reading or printing the response fails.
    """
    try:
        choices = getattr(response, 'choices', None)
        if choices:
            message = getattr(choices[0], 'message', None)
            content = getattr(message, 'content', None)
            # None content is treated as "no content" rather than printed as "None".
            if content is not None:
                logger.info("Non-streaming response received successfully")
                logger.debug(f"Response content: {content}")
                print(f"📄 Response:\n{content}")
                return content.strip()

        logger.warning("No content found in non-streaming response")
        return ""

    except Exception as e:
        logger.error(f"Error processing non-streaming response: {e}")
        raise ChatCompletionError(f"Failed to process non-streaming response: {e}") from e


def _extract_content_from_chunk(chunk: Any) -> Optional[str]:
    """
    Pull the text payload out of one streaming chunk.

    Three layouts are tried in order:

    1. ``chunk.choices[0].delta.content``, used only when it is truthy.
    2. ``chunk.delta.content``, returned as-is whenever ``chunk.delta`` has
       a ``content`` attribute.
    3. ``chunk.content``.

    Args:
        chunk: Individual chunk from streaming response

    Returns:
        The extracted content, or None if no layout matched or the access
        raised AttributeError/IndexError/KeyError (logged at debug level).
    """
    try:
        choices = getattr(chunk, 'choices', None)
        if choices:
            first_delta = choices[0].delta
            text = getattr(first_delta, 'content', None)
            if text:
                return text

        if hasattr(chunk, 'delta'):
            nested_delta = chunk.delta
            if hasattr(nested_delta, 'content'):
                return nested_delta.content

        return getattr(chunk, 'content', None)

    except (AttributeError, IndexError, KeyError) as e:
        logger.debug(f"Could not extract content from chunk: {e}")
        return None


def validate_chat_parameters(
    model_name: str,
    user_message: str,
    **kwargs: Any
) -> None:
    """
    Validate chat completion parameters before a request is sent.

    Among *kwargs*, only ``max_tokens`` and ``temperature`` are checked. Any
    other keyword is accepted without inspection.

    Args:
        model_name: Name of the model
        user_message: User's message content
        **kwargs: Additional parameters

    Raises:
        ValueError: If any of the following holds:
            - model_name is not a string or is empty/whitespace only;
            - user_message is not a string or is empty/whitespace only;
            - max_tokens is not a positive int (bool is rejected);
            - temperature is not a non-negative int/float (bool, NaN and
              infinity are rejected).
    """
    import math

    # Type check first so non-str values are never truth-tested.
    if not isinstance(model_name, str) or not model_name.strip():
        raise ValueError("Model name must be a non-empty string")

    if not isinstance(user_message, str) or not user_message:
        raise ValueError("User message must be a non-empty string")

    if not user_message.strip():
        raise ValueError("User message cannot be empty or whitespace only")

    # Validate additional parameters
    if 'max_tokens' in kwargs:
        max_tokens = kwargs['max_tokens']
        # bool is a subclass of int, so True/False would otherwise pass as 1/0.
        if isinstance(max_tokens, bool) or not isinstance(max_tokens, int) or max_tokens <= 0:
            raise ValueError("max_tokens must be a positive integer")

    if 'temperature' in kwargs:
        temperature = kwargs['temperature']
        if (
            isinstance(temperature, bool)
            or not isinstance(temperature, (int, float))
            # NaN compares False with everything, so `< 0` alone lets it through.
            or (isinstance(temperature, float) and not math.isfinite(temperature))
            or temperature < 0
        ):
            raise ValueError("temperature must be a non-negative number")