"""
Retry mechanisms for the MiniMax client.

This module provides decorators and utilities for implementing retry logic
with exponential backoff and jitter for robustness against transient failures.
"""

import time
import random
import logging
from typing import Callable, Type, Tuple, Optional, Any
from functools import wraps

from .exceptions import RetryableError, RateLimitError, ServiceUnavailableError, TimeoutError


logger = logging.getLogger(__name__)


def exponential_backoff_with_jitter(
    retry_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    retryable_exceptions: Tuple[Type[Exception], ...] = (RetryableError,)
) -> Callable:
    """
    Decorator that implements exponential backoff with optional jitter for retry logic.

    The wrapped function is called once, then retried up to ``retry_attempts``
    more times while it raises one of ``retryable_exceptions``. The delay before
    the retry that follows (0-based) attempt ``k`` is
    ``min(base_delay * backoff_factor ** k, max_delay)``. When ``jitter`` is
    enabled, that value is scaled by a random factor in ``[0.5, 1.0)``. If the
    exception is a ``RateLimitError`` with a truthy ``retry_after``, the delay
    is raised to at least ``retry_after`` seconds, which may exceed ``max_delay``.

    Args:
        retry_attempts: Maximum number of retries after the initial call.
            Negative values are treated as 0, so the function still runs once.
        base_delay: Initial delay between retries in seconds
        max_delay: Maximum computed backoff delay in seconds
        backoff_factor: Factor by which delay increases each retry
        jitter: Whether to add random jitter to prevent thundering herd
        retryable_exceptions: Tuple of exception types that should trigger retries

    Returns:
        Decorated function with retry logic

    Raises:
        The last retryable exception once all retries are exhausted, or any
        other exception immediately (without retrying). Both cases are logged
        at ERROR level before re-raising.
    """
    # Clamp the retry count. With a negative count, range() would be empty and
    # the wrapper would silently return None without ever calling ``func``.
    max_retries = max(0, retry_attempts)

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except retryable_exceptions as e:
                    if attempt == max_retries:
                        # Last attempt failed, re-raise the exception
                        logger.error(f"Function {func.__name__} failed after {max_retries} retries: {e}")
                        raise

                    # Calculate delay with exponential backoff
                    delay = min(base_delay * (backoff_factor ** attempt), max_delay)

                    # Add jitter if enabled
                    if jitter:
                        delay *= (0.5 + random.random() * 0.5)  # Jitter between 50-100% of calculated delay

                    # Honor a server-provided retry hint for rate limits; it
                    # deliberately takes precedence over max_delay.
                    retry_after = getattr(e, 'retry_after', None)
                    if isinstance(e, RateLimitError) and retry_after:
                        delay = max(delay, retry_after)

                    logger.warning(
                        f"Function {func.__name__} failed on attempt {attempt + 1}/{max_retries + 1}: {e}. "
                        f"Retrying in {delay:.2f} seconds..."
                    )

                    time.sleep(delay)
                except Exception as e:
                    # Non-retryable exception, re-raise immediately
                    logger.error(f"Function {func.__name__} failed with non-retryable error: {e}")
                    raise
            # Not reachable: the loop runs at least once, and its final
            # iteration either returns or raises.

        return wrapper
    return decorator


def retry_on_specific_errors(
    exceptions: Tuple[Type[Exception], ...],
    max_attempts: int = 3,
    delay: float = 1.0
) -> Callable:
    """
    Simple retry decorator for specific exception types.

    The wrapped function is retried with a fixed pause whenever it raises one
    of ``exceptions``. Any other exception propagates immediately. Once the
    attempt budget is used up, the last matching exception is re-raised.

    Args:
        exceptions: Tuple of exception types to retry on
        max_attempts: Maximum number of attempts (including initial).
            Values below 1 are treated as 1, so the function always runs once.
        delay: Fixed delay between retries in seconds

    Returns:
        Decorated function with retry logic
    """
    # Clamp the attempt count. With max_attempts <= 0 the loop would be empty
    # and the wrapper would return None without calling ``func``.
    attempts = max(1, max_attempts)

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if attempt == attempts:
                        # Last attempt, re-raise
                        raise

                    logger.warning(
                        f"Function {func.__name__} failed on attempt {attempt}/{attempts}: {e}. "
                        f"Retrying in {delay} seconds..."
                    )

                    time.sleep(delay)
            # Not reachable: the final iteration either returns or raises.

        return wrapper
    return decorator


class RetryConfig:
    """Configuration class for retry behavior.

    Bundles the parameters of ``exponential_backoff_with_jitter`` so that a
    policy can be defined once and turned into a decorator with
    ``create_decorator``.

    Attributes:
        max_attempts: Total number of calls, including the initial one.
        base_delay: Initial delay between retries in seconds.
        max_delay: Maximum computed backoff delay in seconds.
        backoff_factor: Multiplier applied to the delay after each retry.
        jitter: Whether to randomize delays.
        retryable_exceptions: Exception types that trigger a retry.
    """

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_factor: float = 2.0,
        jitter: bool = True,
        retryable_exceptions: Optional[Tuple[Type[Exception], ...]] = None
    ):
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor
        self.jitter = jitter
        # Check for ``None`` explicitly rather than using ``or``. An explicit
        # empty tuple ("retry on nothing") must be kept, not replaced by the
        # defaults.
        if retryable_exceptions is None:
            retryable_exceptions = (
                RetryableError,
                RateLimitError,
                ServiceUnavailableError,
                TimeoutError
            )
        self.retryable_exceptions = retryable_exceptions

    def create_decorator(self) -> Callable:
        """Create a retry decorator with this configuration."""
        return exponential_backoff_with_jitter(
            # max_attempts counts the initial call, while retry_attempts counts
            # only retries. Clamp at 0 so the function always runs at least once.
            retry_attempts=max(self.max_attempts - 1, 0),
            base_delay=self.base_delay,
            max_delay=self.max_delay,
            backoff_factor=self.backoff_factor,
            jitter=self.jitter,
            retryable_exceptions=self.retryable_exceptions
        )

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"max_attempts={self.max_attempts!r}, "
            f"base_delay={self.base_delay!r}, "
            f"max_delay={self.max_delay!r}, "
            f"backoff_factor={self.backoff_factor!r}, "
            f"jitter={self.jitter!r}, "
            f"retryable_exceptions={self.retryable_exceptions!r})"
        )


# Preset retry policies for common scenarios.

# Uses every RetryConfig default unchanged.
DEFAULT_RETRY_CONFIG = RetryConfig()

# More attempts, with shorter delays that grow more slowly and a lower cap.
AGGRESSIVE_RETRY_CONFIG = RetryConfig(
    base_delay=0.5,
    backoff_factor=1.5,
    max_delay=30.0,
    max_attempts=5,
)

# Fewer attempts, with a longer initial delay and a tight cap.
CONSERVATIVE_RETRY_CONFIG = RetryConfig(
    base_delay=2.0,
    backoff_factor=2.0,
    max_delay=10.0,
    max_attempts=2,
)

# Network-oriented policy. It extends the default exception set with
# ConnectionError, which is not imported above and so is Python's built-in.
NETWORK_RETRY_CONFIG = RetryConfig(
    base_delay=1.0,
    backoff_factor=2.0,
    max_delay=45.0,
    max_attempts=4,
    retryable_exceptions=(
        RetryableError,
        RateLimitError,
        ServiceUnavailableError,
        TimeoutError,
        ConnectionError,
    ),
)


def with_retry(config: RetryConfig = DEFAULT_RETRY_CONFIG) -> Callable:
    """
    Turn a ``RetryConfig`` policy into a function decorator.

    This is a factory, so it must be called when decorating: write
    ``@with_retry()`` or ``@with_retry(SOME_CONFIG)``, not bare ``@with_retry``.

    Args:
        config: Policy to apply; ``DEFAULT_RETRY_CONFIG`` when omitted.

    Returns:
        The decorator produced by ``config.create_decorator()``.
    """
    retry_decorator = config.create_decorator()
    return retry_decorator