"""
Caching module for MiniMax Client

Provides AI response caching, configuration caching, and cache management
for improved performance and reduced API calls.
"""

import os
import json
import hashlib
import time
import logging
from pathlib import Path
from typing import Optional, Dict, Any, List, Union
from dataclasses import dataclass, asdict


logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """A single cached value together with the bookkeeping kept for it.

    Entries are written to disk with ``dataclasses.asdict`` and rebuilt with
    ``CacheEntry(**fields)`` when the cache file is loaded again.
    """
    key: str                                       # key the entry is stored under
    data: Any                                      # the cached payload itself
    created_at: float                              # time.time() when the entry was created
    expires_at: Union[float, None]                 # absolute expiry timestamp; None = never expires
    access_count: int = 0                          # number of successful lookups so far
    last_accessed: Union[float, None] = None       # time.time() of the latest access, if any
    metadata: Union[Dict[str, Any], None] = None   # free-form details supplied by the caller


class CacheManager:
    """Advanced cache manager with TTL, size limits, and statistics.

    Entries are held in an in-memory dict and mirrored to
    ``<cache_dir>/ai_responses.json`` after every mutating call.
    Lookup hit/miss counters are persisted to ``<cache_dir>/cache_stats.json``
    so that ``get_stats()['cache_hit_ratio']`` reflects real lookups.
    """

    def __init__(self, cache_dir: Optional[Union[str, Path]] = None,
                 default_ttl: int = 3600,  # 1 hour
                 max_size: int = 100 * 1024 * 1024,  # 100MB
                 max_entries: int = 10000):
        """
        Initialize cache manager.

        Creates ``cache_dir`` (mode 0o700) if needed, then loads the persisted
        hit/miss counters and every non-expired entry from disk.

        Args:
            cache_dir: Directory to store cache files (default ``~/.minimax/cache``)
            default_ttl: Default time-to-live in seconds
            max_size: Maximum cache size in bytes, measured as serialized JSON
            max_entries: Maximum number of cache entries
        """
        if cache_dir is None:
            cache_dir = Path.home() / '.minimax' / 'cache'

        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True, mode=0o700)

        self.default_ttl = default_ttl
        self.max_size = max_size
        self.max_entries = max_entries

        self.ai_cache_file = self.cache_dir / 'ai_responses.json'
        self.config_cache_file = self.cache_dir / 'config_cache.json'
        self.stats_file = self.cache_dir / 'cache_stats.json'

        self._memory_cache: Dict[str, CacheEntry] = {}
        # Lookup counters backing get_stats()['cache_hit_ratio'].
        self._hits = 0
        self._misses = 0
        self._load_stats()
        self._load_persistent_cache()

    def _generate_cache_key(self, data: Any) -> str:
        """Generate a 32-hex-char SHA-256 cache key from data.

        Dicts are serialized with sorted keys so logically equal dicts map to
        the same key; anything else is keyed on ``str(data)``.

        Raises:
            TypeError: if a dict contains values that are not JSON-serializable.
        """
        if isinstance(data, dict):
            # Sort dict for consistent key generation
            sorted_data = json.dumps(data, sort_keys=True, separators=(',', ':'))
        else:
            sorted_data = str(data)

        return hashlib.sha256(sorted_data.encode()).hexdigest()[:32]

    @staticmethod
    def _entry_size(entry: CacheEntry) -> int:
        """Return the size in bytes of *entry* serialized as JSON."""
        return len(json.dumps(asdict(entry)).encode())

    def _lru_order(self) -> List[tuple]:
        """Return ``(key, entry)`` pairs, least recently accessed first."""
        return sorted(
            self._memory_cache.items(),
            key=lambda item: item[1].last_accessed or 0
        )

    def _load_stats(self) -> None:
        """Load persisted hit/miss counters; leave them at 0 if unavailable."""
        try:
            if not self.stats_file.exists():
                return
            with open(self.stats_file, 'r', encoding='utf-8') as f:
                stats = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            logger.warning(f"Could not load cache statistics: {e}")
            return

        if not isinstance(stats, dict):
            return
        hits = stats.get('cache_hits', 0)
        misses = stats.get('cache_misses', 0)
        if isinstance(hits, int) and isinstance(misses, int):
            self._hits, self._misses = hits, misses

    def _save_stats(self) -> None:
        """Persist hit/miss counters to ``stats_file`` (best effort)."""
        stats = {'cache_hits': self._hits, 'cache_misses': self._misses}
        try:
            with open(self.stats_file, 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2)
        except OSError as e:
            logger.warning(f"Could not save cache statistics: {e}")

    def _load_persistent_cache(self) -> None:
        """Load non-expired entries from ``ai_cache_file`` into memory.

        An unreadable or malformed file is logged and ignored, and individual
        malformed entries are skipped, so construction never fails because of
        a corrupt cache file.
        """
        try:
            if not self.ai_cache_file.exists():
                return
            with open(self.ai_cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError/UnicodeDecodeError
            logger.warning(f"Could not load cache: {e}")
            return

        if not isinstance(data, dict):
            logger.warning(f"Could not load cache: expected a JSON object, got {type(data).__name__}")
            return

        for key, entry_data in data.items():
            try:
                entry = CacheEntry(**entry_data)
                if not self._is_expired(entry):
                    self._memory_cache[key] = entry
            except (TypeError, ValueError) as e:
                logger.warning(f"Invalid cache entry {key}: {e}")

    def _save_persistent_cache(self) -> None:
        """Persist all non-expired entries, then the hit/miss counters.

        The JSON is serialized fully in memory, written to a ``.tmp`` sibling
        and only then moved over the real file with ``Path.replace``, so a
        failed save leaves the previous cache file untouched.
        """
        temp_file = self.ai_cache_file.with_suffix('.tmp')
        try:
            # Clean expired entries before saving
            self._cleanup_expired()
            payload = json.dumps(
                {key: asdict(entry) for key, entry in self._memory_cache.items()},
                indent=2
            )
            with open(temp_file, 'w', encoding='utf-8') as f:
                f.write(payload)
            temp_file.replace(self.ai_cache_file)
        except (OSError, TypeError, ValueError) as e:
            # The json module has no JSONEncodeError: unserializable data raises
            # TypeError and circular references raise ValueError.
            logger.error(f"Could not save cache: {e}")
            try:
                temp_file.unlink()
            except OSError:
                pass  # best effort: the temp file may never have been created

        self._save_stats()

    def _is_expired(self, entry: CacheEntry) -> bool:
        """Check if a cache entry is expired (entries without expiry never are)."""
        if entry.expires_at is None:
            return False
        return time.time() > entry.expires_at

    def _cleanup_expired(self) -> int:
        """Remove expired entries from memory and return how many were removed."""
        expired_keys = [
            key for key, entry in self._memory_cache.items()
            if self._is_expired(entry)
        ]

        for key in expired_keys:
            del self._memory_cache[key]

        return len(expired_keys)

    def _enforce_size_limits(self) -> None:
        """Evict least recently accessed entries until both limits hold.

        First trims the cache to ``max_entries``. Then, if the serialized size
        exceeds ``max_size``, evicts until it is at most 80% of ``max_size``
        to leave headroom for new entries.
        """
        excess_count = len(self._memory_cache) - self.max_entries
        if excess_count > 0:
            for key, _ in self._lru_order()[:excess_count]:
                del self._memory_cache[key]

        try:
            # Serialize each entry once; the sizes are reused during eviction.
            sizes = {
                key: self._entry_size(entry)
                for key, entry in self._memory_cache.items()
            }
        except (TypeError, ValueError) as e:
            logger.warning(f"Error checking cache size: {e}")
            return

        total_size = sum(sizes.values())
        if total_size <= self.max_size:
            return

        target_size = self.max_size * 0.8  # Leave some headroom
        for key, _ in self._lru_order():
            del self._memory_cache[key]
            total_size -= sizes[key]
            if total_size <= target_size:
                break

    def get(self, key: str) -> Optional[Any]:
        """
        Get a value from cache.

        Every call counts as a hit or a miss for ``cache_hit_ratio``. An
        expired entry is dropped from memory and counted as a miss.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found/expired
        """
        entry = self._memory_cache.get(key)
        if entry is None:
            self._misses += 1
            return None

        if self._is_expired(entry):
            del self._memory_cache[key]
            self._misses += 1
            return None

        # Update access statistics
        self._hits += 1
        entry.access_count += 1
        entry.last_accessed = time.time()

        return entry.data

    def _build_entry(self, key: str, value: Any, ttl: Optional[int],
                     metadata: Optional[Dict[str, Any]]) -> CacheEntry:
        """Create a fresh entry for *value*.

        ``ttl=None`` means no expiration. Any other falsy ttl (i.e. 0) falls
        back to ``default_ttl``.

        Raises:
            TypeError/ValueError: if the entry cannot be serialized to JSON.
                It is rejected up front, because an entry that cannot be saved
                would make every later save of the whole cache fail.
        """
        current_time = time.time()
        expires_at = None if ttl is None else current_time + (ttl or self.default_ttl)

        entry = CacheEntry(
            key=key,
            data=value,
            created_at=current_time,
            expires_at=expires_at,
            access_count=0,
            last_accessed=current_time,
            metadata=metadata or {}
        )
        json.dumps(asdict(entry))  # serializability check only
        return entry

    def set(self, key: str, value: Any, ttl: Optional[int] = None,
            metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Set a value in cache and persist the cache to disk.

        Args:
            key: Cache key
            value: Value to cache (must be JSON-serializable)
            ttl: Time-to-live in seconds (None for no expiration; 0 means default_ttl)
            metadata: Additional metadata for the entry

        Returns:
            True if successfully cached. False on error, including values
            that cannot be serialized; in that case nothing is stored.
        """
        try:
            self._memory_cache[key] = self._build_entry(key, value, ttl, metadata)

            # Enforce limits and save
            self._enforce_size_limits()
            self._save_persistent_cache()

            return True

        except Exception as e:
            logger.error(f"Error caching value: {e}")
            return False

    def delete(self, key: str) -> bool:
        """Delete a cache entry; return True if it existed."""
        if key in self._memory_cache:
            del self._memory_cache[key]
            self._save_persistent_cache()
            return True
        return False

    def clear(self) -> int:
        """Clear all cache entries and delete the cache files.

        Hit/miss statistics are kept. Returns the number of entries that were
        in memory.
        """
        count = len(self._memory_cache)
        self._memory_cache.clear()

        # Remove cache files
        for cache_file in [self.ai_cache_file, self.config_cache_file]:
            try:
                if cache_file.exists():
                    cache_file.unlink()
            except OSError as e:
                logger.warning(f"Could not remove cache file {cache_file}: {e}")

        return count

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Sizes are measured as serialized JSON. Entries that cannot be
        serialized are excluded from the size, access-count and age figures.
        The utilization percentages are 0.0 when the corresponding limit is 0.
        """
        current_time = time.time()
        entries = list(self._memory_cache.values())

        total_entries = len(entries)
        expired_entries = sum(1 for entry in entries if self._is_expired(entry))

        total_size = 0
        total_access_count = 0
        oldest_entry = current_time
        newest_entry = 0

        for entry in entries:
            try:
                total_size += self._entry_size(entry)
            except (TypeError, ValueError):
                continue
            total_access_count += entry.access_count
            oldest_entry = min(oldest_entry, entry.created_at)
            newest_entry = max(newest_entry, entry.created_at)

        size_utilization = (
            round((total_size / self.max_size) * 100, 2) if self.max_size else 0.0
        )
        entry_utilization = (
            round((total_entries / self.max_entries) * 100, 2) if self.max_entries else 0.0
        )

        return {
            'total_entries': total_entries,
            'expired_entries': expired_entries,
            'valid_entries': total_entries - expired_entries,
            'total_size_bytes': total_size,
            'total_size_mb': round(total_size / 1024 / 1024, 2),
            'total_access_count': total_access_count,
            'cache_hit_ratio': self._calculate_hit_ratio(),
            'oldest_entry_age_seconds': current_time - oldest_entry if total_entries > 0 else 0,
            'newest_entry_age_seconds': current_time - newest_entry if total_entries > 0 else 0,
            'max_size_mb': round(self.max_size / 1024 / 1024, 2),
            'max_entries': self.max_entries,
            'size_utilization_percent': size_utilization,
            'entry_utilization_percent': entry_utilization
        }

    def _calculate_hit_ratio(self) -> float:
        """Return the lookup hit ratio as a percentage, rounded to 2 decimals.

        Counts include the totals persisted by earlier runs and loaded in
        ``__init__``. Returns 0.0 if no lookups have been recorded.
        """
        total = self._hits + self._misses
        return round(self._hits / total * 100, 2) if total > 0 else 0.0

    def _ai_cache_key(self, prompt: str, model_params: Dict[str, Any]) -> str:
        """Derive the cache key for an AI request from its prompt and parameters."""
        cache_key_data = {
            'prompt_hash': hashlib.sha256(prompt.encode()).hexdigest()[:16],
            'model_params': model_params
        }
        return self._generate_cache_key(cache_key_data)

    def cache_ai_response(self, prompt: str, model_params: Dict[str, Any],
                          response: str, ttl: Optional[int] = None) -> str:
        """
        Cache an AI response with automatic key generation.

        Args:
            prompt: The prompt used for the AI request
            model_params: Model parameters used (must be JSON-serializable)
            response: The AI response to cache
            ttl: Time-to-live for the cache entry (falsy -> default_ttl)

        Returns:
            The cache key used
        """
        cache_key = self._ai_cache_key(prompt, model_params)

        metadata = {
            'type': 'ai_response',
            'prompt_length': len(prompt),
            'response_length': len(response),
            'model': model_params.get('model_name', 'unknown')
        }

        self.set(cache_key, response, ttl or self.default_ttl, metadata)
        return cache_key

    def get_cached_ai_response(self, prompt: str, model_params: Dict[str, Any]) -> Optional[str]:
        """
        Get a cached AI response.

        Args:
            prompt: The prompt to look up
            model_params: Model parameters used

        Returns:
            Cached response or None if not found
        """
        return self.get(self._ai_cache_key(prompt, model_params))

    def warm_cache(self, warm_data: List[Dict[str, Any]]) -> int:
        """
        Pre-populate cache with frequently used data.

        Each item needs a ``'data'`` value and may supply ``'key'`` (derived
        from the data if missing), ``'ttl'`` (default ``default_ttl``) and
        ``'metadata'``. The caller's metadata dicts are not modified. Size
        limits are enforced and the cache is saved once for the whole batch
        instead of once per item.

        Args:
            warm_data: List of cache entries to pre-populate

        Returns:
            Number of entries successfully cached
        """
        cached_count = 0

        for item in warm_data:
            try:
                data = item['data']
                key = item.get('key') or self._generate_cache_key(data)
                ttl = item.get('ttl', self.default_ttl)
                metadata = dict(item.get('metadata') or {})  # copy: never mutate caller data
                metadata['warmed'] = True
            except (KeyError, TypeError, ValueError, AttributeError) as e:
                logger.warning(f"Invalid warm cache entry: {e}")
                continue

            try:
                self._memory_cache[key] = self._build_entry(key, data, ttl, metadata)
            except Exception as e:
                logger.error(f"Error caching value: {e}")
                continue
            cached_count += 1

        if cached_count:
            self._enforce_size_limits()
            self._save_persistent_cache()

        return cached_count

    def cleanup(self) -> Dict[str, int]:
        """
        Perform cache cleanup and return statistics.

        Returns:
            Dictionary with cleanup statistics
        """
        expired_count = self._cleanup_expired()

        entries_before_limits = len(self._memory_cache)
        self._enforce_size_limits()
        size_reduced_count = entries_before_limits - len(self._memory_cache)

        self._save_persistent_cache()

        return {
            'expired_entries_removed': expired_count,
            'size_limit_entries_removed': size_reduced_count,
            'total_entries_removed': expired_count + size_reduced_count,
            'remaining_entries': len(self._memory_cache)
        }


# Module-level CacheManager, created lazily by get_cache_manager().
_cache_manager: Optional[CacheManager] = None

def get_cache_manager() -> CacheManager:
    """Return the shared CacheManager, constructing it on the first call.

    The instance is built with CacheManager's default settings (cache files
    under ``~/.minimax/cache``). Creation is not guarded by a lock.
    """
    global _cache_manager
    if _cache_manager is not None:
        return _cache_manager
    _cache_manager = CacheManager()
    return _cache_manager

def cache_ai_response(prompt: str, model_params: Dict[str, Any],
                      response: str, ttl: Optional[int] = None) -> str:
    """Store *response* for (*prompt*, *model_params*) in the shared cache.

    Thin wrapper over ``CacheManager.cache_ai_response`` on the instance
    returned by ``get_cache_manager()``. Returns the cache key that was used.
    """
    manager = get_cache_manager()
    return manager.cache_ai_response(prompt, model_params, response, ttl)

def get_cached_ai_response(prompt: str, model_params: Dict[str, Any]) -> Optional[str]:
    """Look up a cached response for (*prompt*, *model_params*) in the shared cache.

    Thin wrapper over ``CacheManager.get_cached_ai_response`` on the instance
    returned by ``get_cache_manager()``. Returns None when nothing valid is cached.
    """
    manager = get_cache_manager()
    return manager.get_cached_ai_response(prompt, model_params)