"""
Async chat completion module for MiniMax model interactions.

This module handles async chat completions (streaming and non-streaming) for MiniMax
models such as MiniMax-M1-80k via huggingface_hub's AsyncInferenceClient. It provides
error translation, retry logic, cancellation support, and helpers for Textual's worker
system so that UI operations stay non-blocking.
"""

import asyncio
import logging
import sys
from typing import Optional, Dict, Any, AsyncIterator, Callable, Union
from contextlib import asynccontextmanager

# Try to import async HuggingFace client
try:
    from huggingface_hub import AsyncInferenceClient
except ImportError:
    AsyncInferenceClient = None

# Import our custom exceptions and retry mechanisms
from .exceptions import (
    MiniMaxClientError, 
    NetworkError, 
    ModelError, 
    APIError, 
    RateLimitError,
    ServiceUnavailableError,
    TimeoutError as MiniMaxTimeoutError
)
from .retry import with_retry, NETWORK_RETRY_CONFIG

# Try to import HuggingFace exceptions
try:
    from huggingface_hub import (
        HfHubHTTPError,
        RepositoryNotFoundError,
        GatedRepoError,
        BadRequestError,
        InferenceTimeoutError
    )
except ImportError:
    # Fallback for older versions
    HfHubHTTPError = Exception
    RepositoryNotFoundError = Exception
    GatedRepoError = Exception
    BadRequestError = Exception
    InferenceTimeoutError = Exception

# Try to import aiohttp for better error handling
try:
    import aiohttp
except ImportError:
    aiohttp = None

# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)


class AsyncChatCompletionError(MiniMaxClientError):
    """
    Raised when the async client is unavailable, when a chat completion fails
    unexpectedly, or when its response cannot be processed.
    """


class AsyncChatCompletion:
    """
    Async chat completion handler that provides non-blocking streaming
    for Textual integration with proper cancellation support.

    Holds an optional client and an optional in-flight task
    (``_current_task``). ``cancel()`` requests cancellation of that task.
    ``cleanup()`` cancels it, waits for it to finish, and closes the client.
    The handler works as an ``async with`` context manager, which calls
    ``cleanup()`` on exit.
    """

    def __init__(self, client: Optional[AsyncInferenceClient] = None):
        """
        Initialize the async chat completion handler.

        Args:
            client: Optional pre-configured AsyncInferenceClient; closed by
                ``cleanup()``.
        """
        self.client = client
        # Only initialised here; presumably callers assign the task they want
        # cancel()/cleanup() to manage -- verify against callers.
        self._current_task: Optional[asyncio.Task] = None
        # Set by cancel() and cleanup(); never reset by this class.
        self._cancelled = False

    async def __aenter__(self):
        """Async context manager entry; returns the handler itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: run ``cleanup()`` (body exceptions are not suppressed)."""
        await self.cleanup()

    async def cleanup(self):
        """
        Cancel the tracked task (if still running) and close the client.

        The client is closed in a ``finally`` block, so it is released even if
        waiting on the task is interrupted. If the task raises an ordinary
        exception while winding down, that exception is logged rather than
        propagated, so it can no longer prevent the client from closing.
        Cancellation of ``cleanup()`` itself still propagates.
        """
        self._cancelled = True
        task = self._current_task
        try:
            # A task cannot wait on itself, so skip the cancel/wait step when
            # cleanup() is running inside the tracked task.
            if task and not task.done() and task is not asyncio.current_task():
                task.cancel()
                # asyncio.wait() neither raises the task's own exception nor
                # its CancelledError; only cancellation of *this* coroutine
                # escapes, which is what callers expect.
                await asyncio.wait({task})
                # Retrieve the exception (if any) so asyncio does not report
                # "Task exception was never retrieved".
                if not task.cancelled() and task.exception() is not None:
                    logger.warning(
                        f"Task raised while being cancelled: {task.exception()}"
                    )
        finally:
            if self.client:
                try:
                    await self.client.close()
                except Exception as e:
                    logger.warning(f"Error closing async client: {e}")

    def cancel(self):
        """Request cancellation of the current operation without waiting for it."""
        self._cancelled = True
        if self._current_task and not self._current_task.done():
            self._current_task.cancel()


async def create_async_chat_completion(
    client: AsyncInferenceClient,
    model_name: str,
    user_message: str,
    stream: bool = True,
    progress_callback: Optional[Callable[[str], None]] = None,
    **kwargs: Any
) -> Optional[str]:
    """
    Create an async streaming or non-streaming chat completion request.

    Sends ``user_message`` as a single user turn. Failures from the Hugging
    Face client are translated into this package's exception types. Errors
    that are already ``MiniMaxClientError`` instances (e.g. raised by the
    retry helpers or response processors) are re-raised unchanged instead of
    being re-wrapped as "unexpected".

    Args:
        client: The configured async inference client
        model_name: Name of the model to use for completion
        user_message: The user's message content
        stream: Whether to use streaming response (default: True)
        progress_callback: Optional callback for streaming content updates
            (only used when ``stream`` is True)
        **kwargs: Forwarded to the retry helpers (e.g. ``max_retries``) and
            from there to ``client.chat_completion``

    Returns:
        The complete (stripped) response content. Failures are reported by
        raising, not by returning ``None``.

    Raises:
        AsyncChatCompletionError: Async client unavailable, or an unexpected error
        MiniMaxTimeoutError: Model inference timed out
        NetworkError: Network-level failure
        ModelError: Model missing, not accessible, or gated
        APIError: Bad request, authentication failure, or other HTTP error
        RateLimitError: HTTP 429
        ServiceUnavailableError: HTTP 5xx
        asyncio.CancelledError: If the operation was cancelled
    """
    if AsyncInferenceClient is None:
        raise AsyncChatCompletionError(
            "AsyncInferenceClient not available. Please install huggingface_hub with async support: pip install 'huggingface_hub[inference]'",
            error_code="ASYNC_CLIENT_UNAVAILABLE"
        )

    messages = [
        {
            "role": "user",
            "content": user_message
        }
    ]

    logger.info(f"Sending async request to model: {model_name}")
    # The message body may contain user data; keep it out of INFO-level logs.
    logger.debug(f"User message: {user_message}")

    try:
        if stream:
            logger.info("Starting async streaming response...")
            return await _create_streaming_completion_with_retry(
                client, model_name, messages, progress_callback, **kwargs
            )
        return await _create_non_streaming_completion_with_retry(
            client, model_name, messages, **kwargs
        )

    except asyncio.CancelledError:
        logger.info("Async chat completion was cancelled")
        raise

    except MiniMaxClientError:
        # Already translated. This must precede the handlers below, which
        # would otherwise re-wrap it -- including when the huggingface_hub
        # import fallback aliases all of them to ``Exception``.
        raise

    except InferenceTimeoutError as e:
        logger.error(f"Model inference timeout: {e}")
        raise MiniMaxTimeoutError(
            "Model inference timeout - the model may be unavailable or overloaded",
            error_code="INFERENCE_TIMEOUT",
            details={"model_name": model_name, "user_message_length": len(user_message)}
        ) from e

    # In huggingface_hub, GatedRepoError subclasses RepositoryNotFoundError,
    # and both (plus BadRequestError) subclass HfHubHTTPError, so the specific
    # handlers must come first or they can never run.
    except GatedRepoError as e:
        logger.error(f"Access denied to gated repository: {e}")
        raise ModelError(
            f"The model '{model_name}' requires special access - please request access through Hugging Face Hub",
            error_code="GATED_REPOSITORY",
            details={"model_name": model_name}
        ) from e

    except RepositoryNotFoundError as e:
        logger.error(f"Repository not found: {e}")
        raise ModelError(
            f"The model '{model_name}' does not exist or is not accessible",
            error_code="REPOSITORY_NOT_FOUND",
            details={"model_name": model_name}
        ) from e

    except BadRequestError as e:
        logger.error(f"Bad request: {e}")
        raise APIError(
            "Bad request - invalid parameters or configuration",
            error_code="BAD_REQUEST",
            details={"message": str(e)}
        ) from e

    except HfHubHTTPError as e:
        await _handle_http_error(e, model_name)
        # _handle_http_error always raises; re-raise defensively rather than
        # silently returning None if that ever changes.
        raise

    except Exception as e:
        if aiohttp and isinstance(e, (aiohttp.ClientError, aiohttp.ServerTimeoutError)):
            logger.error(f"Network connection error: {e}")
            raise NetworkError(
                "Network connection failed",
                error_code="CONNECTION_ERROR",
                details={"error_type": type(e).__name__}
            ) from e

        logger.error(f"Unexpected error during async chat completion: {e}")
        raise AsyncChatCompletionError(
            f"Unexpected error during async chat completion: {str(e)}",
            error_code="UNEXPECTED_ERROR",
            details={"error_type": type(e).__name__, "model_name": model_name}
        ) from e


def _is_transient_error(exc: BaseException) -> bool:
    """
    Return True if ``exc`` looks like a transient failure worth retrying.

    Retryable errors are:
    - this package's ``NetworkError`` / ``ServiceUnavailableError``;
    - aiohttp client errors, except HTTP responses with status below 500;
    - any error whose ``response.status_code`` is 500 or above (the statuses
      ``_handle_http_error`` maps to ``ServiceUnavailableError``).
    """
    if isinstance(exc, (NetworkError, ServiceUnavailableError)):
        return True
    if aiohttp is not None and isinstance(exc, aiohttp.ClientError):
        # Response errors carry an HTTP status; a 4xx is not transient.
        status = getattr(exc, "status", None)
        return not (isinstance(status, int) and status < 500)
    # Duck-typed on purpose: stays correct even when the huggingface_hub error
    # imports fell back to ``Exception`` (plain exceptions have no response).
    status = getattr(getattr(exc, "response", None), "status_code", None)
    return isinstance(status, int) and status >= 500


async def _call_with_retry(operation: Callable[[], Any], max_retries: int) -> Any:
    """
    Await ``operation()``, retrying transient failures with exponential backoff.

    Makes at most ``max_retries + 1`` attempts (always at least one) and sleeps
    1s, 2s, 4s, ... between them. Non-transient errors, and the error from the
    final attempt, propagate unchanged.
    """
    attempts = max(max_retries, 0) + 1
    for attempt in range(attempts):
        try:
            return await operation()
        except Exception as e:
            if attempt + 1 >= attempts or not _is_transient_error(e):
                raise
            wait_time = 2 ** attempt
            logger.warning(f"Attempt {attempt + 1} failed, retrying in {wait_time}s: {e}")
            await asyncio.sleep(wait_time)
    # Not reached: every iteration either returns or raises.


async def _create_streaming_completion_with_retry(
    client: AsyncInferenceClient,
    model_name: str,
    messages: list,
    progress_callback: Optional[Callable[[str], None]],
    max_retries: int = 3,
    **kwargs: Any
) -> str:
    """
    Create a streaming completion, retrying transient failures.

    ``process_async_streaming_response`` turns any error while consuming the
    stream into either a partial result or ``AsyncChatCompletionError``, and
    the latter is not transient. So only failures of the
    ``client.chat_completion`` call itself are retried, and no content is
    delivered to ``progress_callback`` twice.
    """
    async def attempt() -> str:
        response = await client.chat_completion(
            model=model_name,
            messages=messages,
            stream=True,
            **kwargs
        )
        return await process_async_streaming_response(response, progress_callback)

    return await _call_with_retry(attempt, max_retries)


async def _create_non_streaming_completion_with_retry(
    client: AsyncInferenceClient,
    model_name: str,
    messages: list,
    max_retries: int = 3,
    **kwargs: Any
) -> str:
    """Create a non-streaming completion, retrying transient failures."""
    async def attempt() -> str:
        response = await client.chat_completion(
            model=model_name,
            messages=messages,
            stream=False,
            **kwargs
        )
        return await _process_async_non_streaming_response(response)

    return await _call_with_retry(attempt, max_retries)


async def process_async_streaming_response(
    stream: AsyncIterator[Any],
    progress_callback: Optional[Callable[[str], None]] = None
) -> str:
    """
    Process async streaming response from the model with enhanced error recovery.

    - Chunks whose content cannot be extracted, or whose content is not a
      string, are logged and skipped.
    - Exceptions from ``progress_callback`` are logged and do not stop the
      stream.
    - If consuming the stream fails after some content has arrived, the
      partial content is returned instead of raising.

    Args:
        stream: Async iterator of response chunks
        progress_callback: Optional callback for real-time content updates

    Returns:
        Complete (stripped) response content

    Raises:
        asyncio.CancelledError: If the operation was cancelled (any partial
            content is discarded)
        AsyncChatCompletionError: If processing fails before any content arrived
    """
    # Collect pieces and join once; repeated ``str +=`` is quadratic.
    parts = []
    chunk_count = 0

    try:
        # No explicit cancellation check is needed: asyncio delivers
        # cancellation as CancelledError at the ``async for`` await point.
        async for chunk in stream:
            chunk_count += 1

            try:
                content = _extract_content_from_chunk(chunk)
            except Exception as e:
                logger.warning(f"Error processing chunk {chunk_count}: {e}")
                continue

            if not content:
                continue
            if not isinstance(content, str):
                logger.warning(
                    f"Error processing chunk {chunk_count}: "
                    f"unexpected content type {type(content).__name__}"
                )
                continue

            parts.append(content)

            if progress_callback:
                try:
                    progress_callback(content)
                except Exception as e:
                    logger.warning(f"Error in progress callback: {e}")

    except asyncio.CancelledError:
        logger.info("Async streaming was cancelled")
        if parts:
            logger.info(
                f"Discarding partial response on cancellation: {sum(map(len, parts))} characters"
            )
        raise

    except Exception as e:
        logger.error(f"Error during async streaming response processing: {e}")
        if parts:
            partial = "".join(parts)
            logger.info(f"Partial response received: {len(partial)} characters")
            return partial.strip()
        raise AsyncChatCompletionError(f"Failed to process async streaming response: {e}") from e

    response_content = "".join(parts).strip()

    logger.info("Async streaming completed successfully")
    logger.debug(f"Processed {chunk_count} chunks")

    if response_content:
        logger.debug(f"Complete response: {response_content}")
    else:
        logger.warning("No content received in async streaming response")

    return response_content


async def _process_async_non_streaming_response(response: Any) -> str:
    """
    Pull the assistant text out of a non-streaming chat completion response.

    Only ``response.choices[0].message.content`` is consulted. When that path
    is missing, a warning is logged and an empty string is returned.

    Args:
        response: Object returned by ``chat_completion(stream=False)``

    Returns:
        The stripped message content, or ``""`` if there is none

    Raises:
        AsyncChatCompletionError: If reading the response raises unexpectedly
    """
    try:
        choices = getattr(response, 'choices', None)
        message = getattr(choices[0], 'message', None) if choices else None

        if not hasattr(message, 'content'):
            logger.warning("No content found in async non-streaming response")
            return ""

        text = message.content
        logger.info("Async non-streaming response received successfully")
        logger.debug(f"Response content: {text}")
        if not text:
            return ""
        return text.strip()

    except Exception as e:
        logger.error(f"Error processing async non-streaming response: {e}")
        raise AsyncChatCompletionError(f"Failed to process async non-streaming response: {e}")


def _extract_content_from_chunk(chunk: Any) -> Optional[str]:
    """
    Extract content from a streaming response chunk with robust error handling.
    
    Args:
        chunk: Individual chunk from streaming response
        
    Returns:
        Extracted content or None if no content found
    """
    try:
        # Handle different chunk formats
        if hasattr(chunk, 'choices') and chunk.choices:
            delta = chunk.choices[0].delta
            if hasattr(delta, 'content') and delta.content:
                return delta.content
                
        # Alternative chunk format
        if hasattr(chunk, 'delta') and hasattr(chunk.delta, 'content'):
            return chunk.delta.content
            
        # Direct content access
        if hasattr(chunk, 'content'):
            return chunk.content
            
        return None
        
    except (AttributeError, IndexError, KeyError) as e:
        logger.debug(f"Could not extract content from chunk: {e}")
        return None


async def _handle_http_error(e: HfHubHTTPError, model_name: str):
    """
    Map a Hugging Face Hub HTTP error onto this package's exception types.

    This coroutine never returns normally; every branch raises.

    Args:
        e: The HTTP error raised by the Hugging Face client
        model_name: Model identifier, echoed in ``ModelError`` for HTTP 404

    Raises:
        NetworkError: ``e`` carries no HTTP response
        RateLimitError: HTTP 429; ``retry_after`` is parsed from the
            ``Retry-After`` header when it holds an integer
        ModelError: HTTP 404
        APIError: HTTP 401, and any other status below 500
        ServiceUnavailableError: HTTP 500 and above
    """
    response = getattr(e, 'response', None)

    # Without a response there is no status to classify by.
    if response is None:
        logger.error(f"HTTP error: {e}")
        raise NetworkError(
            "Network error occurred during API request",
            error_code="NETWORK_ERROR"
        )

    status_code = response.status_code

    if status_code == 429:
        logger.error(f"Rate limit exceeded: {e}")
        retry_after = None
        if hasattr(response, 'headers'):
            header_value = response.headers.get('Retry-After')
            try:
                retry_after = int(header_value) if header_value else header_value
            except ValueError:
                retry_after = None
        raise RateLimitError(
            "Rate limit exceeded - please wait before making another request",
            retry_after=retry_after,
            error_code="RATE_LIMIT_EXCEEDED",
            details={"status_code": status_code}
        )

    if status_code == 404:
        logger.error(f"Model not found: {e}")
        raise ModelError(
            f"Model '{model_name}' not found or not accessible",
            error_code="MODEL_NOT_FOUND",
            details={"model_name": model_name}
        )

    if status_code == 401:
        logger.error(f"Authentication failed: {e}")
        raise APIError(
            "Authentication failed - please check your HF_TOKEN",
            error_code="AUTHENTICATION_FAILED",
            details={"status_code": status_code}
        )

    if status_code >= 500:
        logger.error(f"Server error {status_code}: {e}")
        raise ServiceUnavailableError(
            f"Server error {status_code} - service temporarily unavailable",
            error_code="SERVER_ERROR",
            details={"status_code": status_code}
        )

    logger.error(f"HTTP error {status_code}: {e}")
    raise APIError(
        f"HTTP error {status_code}",
        error_code="HTTP_ERROR",
        details={"status_code": status_code}
    )


async def validate_async_chat_parameters(
    model_name: str,
    user_message: str,
    **kwargs: Any
) -> None:
    """
    Validate async chat completion parameters.

    Args:
        model_name: Name of the model
        user_message: User's message content
        **kwargs: Additional parameters; ``max_tokens`` and ``temperature``
            are checked when present, everything else is ignored

    Raises:
        ValueError: If parameters are invalid. This includes booleans passed
            as ``max_tokens``/``temperature`` and non-finite temperatures.
    """
    import math  # local import keeps this block self-contained

    if not model_name or not isinstance(model_name, str):
        raise ValueError("Model name must be a non-empty string")

    if not user_message or not isinstance(user_message, str):
        raise ValueError("User message must be a non-empty string")

    if not user_message.strip():
        raise ValueError("User message cannot be empty or whitespace only")

    # bool is a subclass of int, so True/False must be rejected explicitly.
    if 'max_tokens' in kwargs:
        max_tokens = kwargs['max_tokens']
        if (isinstance(max_tokens, bool)
                or not isinstance(max_tokens, int)
                or max_tokens <= 0):
            raise ValueError("max_tokens must be a positive integer")

    if 'temperature' in kwargs:
        temperature = kwargs['temperature']
        # NaN compares False against everything, so "< 0" alone lets it through.
        if (isinstance(temperature, bool)
                or not isinstance(temperature, (int, float))
                or (isinstance(temperature, float) and not math.isfinite(temperature))
                or temperature < 0):
            raise ValueError("temperature must be a non-negative number")


@asynccontextmanager
async def create_async_client(api_key: str, provider: Optional[str] = None):
    """
    Create an async inference client with proper resource management.

    Only errors raised while *constructing* the client are wrapped in
    ``AsyncChatCompletionError``. Exceptions raised inside the caller's
    ``async with`` body propagate unchanged. The client is closed on exit
    either way, and errors from ``close()`` are logged rather than raised.

    Args:
        api_key: The Hugging Face API token
        provider: Optional provider specification

    Yields:
        AsyncInferenceClient: Configured async inference client

    Raises:
        AsyncChatCompletionError: If the async client class is unavailable or
            client creation fails
    """
    if AsyncInferenceClient is None:
        raise AsyncChatCompletionError(
            "AsyncInferenceClient not available. Please install huggingface_hub with async support: pip install 'huggingface_hub[inference]'",
            error_code="ASYNC_CLIENT_UNAVAILABLE"
        )

    try:
        logger.info("Creating AsyncInferenceClient...")
        client = AsyncInferenceClient(
            provider=provider,
            token=api_key
        )
    except Exception as e:
        logger.error(f"Failed to create AsyncInferenceClient: {e}")
        raise AsyncChatCompletionError(
            f"Failed to create async client: {str(e)}",
            error_code="CLIENT_CREATION_FAILED"
        ) from e
    logger.info("✓ AsyncInferenceClient created successfully")

    # The yield sits outside the construction try/except: an exception thrown
    # into the generator here comes from the caller's body and must not be
    # relabelled as a client-creation failure.
    try:
        yield client
    finally:
        try:
            await client.close()
            logger.info("✓ AsyncInferenceClient closed successfully")
        except Exception as e:
            logger.warning(f"Error closing AsyncInferenceClient: {e}")


# Textual worker integration helpers
class TextualWorkerMixin:
    """
    Helpers for driving async chat completions from Textual workers.

    The single helper is a static method, so it needs no instance state from
    the class it is mixed into.
    """

    @staticmethod
    async def run_chat_completion_worker(
        client: AsyncInferenceClient,
        model_name: str,
        user_message: str,
        progress_callback: Optional[Callable[[str], None]] = None,
        **kwargs: Any
    ) -> str:
        """
        Run a streaming chat completion from inside a Textual worker.

        A thin wrapper around ``create_async_chat_completion`` with
        ``stream=True``. Cancellation is logged at INFO and any other error at
        ERROR; both are re-raised unchanged.

        Args:
            client: Async inference client
            model_name: Model to use
            user_message: User's message
            progress_callback: Receives each streamed piece of content
            **kwargs: Extra chat parameters, passed through untouched

        Returns:
            The complete response content
        """
        try:
            reply = await create_async_chat_completion(
                client=client,
                model_name=model_name,
                user_message=user_message,
                stream=True,
                progress_callback=progress_callback,
                **kwargs
            )
        except asyncio.CancelledError:
            logger.info("Chat completion worker was cancelled")
            raise
        except Exception as exc:
            logger.error(f"Error in chat completion worker: {exc}")
            raise
        return reply