"""Audio transcription using faster-whisper."""

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from faster_whisper import WhisperModel


@dataclass
class TranscriptSegment:
    """A single segment of transcribed audio.

    Attributes:
        start_ms: Start time of the segment in milliseconds (presumably
            measured from the beginning of the audio -- confirm against
            the producer).
        end_ms: End time of the segment in milliseconds, on the same
            timeline as ``start_ms``.
        text: Transcribed text of the segment. ``AudioTranscriber.transcribe``
            stores it with surrounding whitespace stripped and never emits
            segments whose text is empty.
        confidence: Optional quality score; ``None`` when unavailable.
            ``AudioTranscriber.transcribe`` fills it with the segment's
            ``avg_logprob``. NOTE(review): that looks like an average
            log-probability rather than a 0..1 probability -- confirm
            before comparing it against thresholds.
    """

    start_ms: int
    end_ms: int
    text: str
    confidence: float | None = None


@dataclass
class Transcript:
    """Complete transcription result.

    Attributes:
        segments: Transcribed segments, in the order they were produced.
        full_text: Whole transcript as one string.
            ``AudioTranscriber.transcribe`` builds it by joining the
            segment texts with single spaces.
        duration_ms: Duration in milliseconds. ``AudioTranscriber.transcribe``
            derives it from the duration reported by the model, raised to
            the end of the last segment if that is later; 0 when no
            duration is reported and there are no segments.
        language: Language code reported for the audio, or ``None`` when
            none was reported.
    """

    segments: list[TranscriptSegment]
    full_text: str
    duration_ms: int
    language: str | None = None


class AudioTranscriber:
    """Transcribes audio files using faster-whisper.

    The model is loaded lazily on first use to improve startup time.
    """

    def __init__(self, model_size: str = "base") -> None:
        """Initialize the transcriber.

        Args:
            model_size: Whisper model size (tiny, base, small, medium, large-v2, large-v3).
                       Defaults to 'base' for a good balance of speed and accuracy.
        """
        self.model_size = model_size
        self._model: WhisperModel | None = None

    def _get_model(self) -> "WhisperModel":
        """Lazily load the whisper model on first use."""
        if self._model is None:
            from faster_whisper import WhisperModel

            # Use CPU by default; faster-whisper will use GPU if available
            self._model = WhisperModel(self.model_size, device="auto", compute_type="auto")
        return self._model

    def transcribe(self, audio_path: str, language: str | None = None) -> Transcript:
        """Transcribe an audio file.

        Args:
            audio_path: Path to the audio file (supports mp3, wav, m4a, etc.)
            language: Optional language code (e.g., 'en', 'es'). Auto-detected if not provided.

        Returns:
            Transcript object with segments and full text.

        Raises:
            FileNotFoundError: If the audio file doesn't exist.
            RuntimeError: If transcription fails.
        """
        import os

        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        model = self._get_model()

        try:
            segments_iter, info = model.transcribe(
                audio_path,
                language=language,
                beam_size=5,
                word_timestamps=False,
                vad_filter=True,  # Filter out silence
            )

            segments: list[TranscriptSegment] = []
            full_text_parts: list[str] = []

            for segment in segments_iter:
                # Convert seconds to milliseconds
                start_ms = int(segment.start * 1000)
                end_ms = int(segment.end * 1000)
                text = segment.text.strip()

                if text:  # Only include non-empty segments
                    segments.append(
                        TranscriptSegment(
                            start_ms=start_ms,
                            end_ms=end_ms,
                            text=text,
                            confidence=segment.avg_logprob
                            if hasattr(segment, "avg_logprob")
                            else None,
                        )
                    )
                    full_text_parts.append(text)

            # Calculate duration from the last segment or audio info
            duration_ms = int(info.duration * 1000) if info.duration else 0
            if segments and segments[-1].end_ms > duration_ms:
                duration_ms = segments[-1].end_ms

            return Transcript(
                segments=segments,
                full_text=" ".join(full_text_parts),
                duration_ms=duration_ms,
                language=info.language if info.language else None,
            )

        except Exception as e:
            raise RuntimeError(f"Transcription failed: {e}") from e
