import threading
import time
from collections.abc import Callable
from typing import Any

try:
    import AppKit
    import Quartz
    from AppKit import NSEvent, NSWorkspace
except ImportError:
    AppKit = None
    Quartz = None
    NSEvent = None
    NSWorkspace = None


# Modifier-key bit masks, keyed by the short names used in emitted payloads.
# Each value is read from AppKit's NSEventModifierFlag* constant when AppKit is
# importable; otherwise the hard-coded bit position is used. getattr() on a None
# AppKit has no such attribute, so it returns the fallback directly.
MODIFIER_FLAGS = {
    short_name: int(getattr(AppKit, appkit_name, 1 << bit))
    for short_name, appkit_name, bit in (
        ("cmd", "NSEventModifierFlagCommand", 20),
        ("alt", "NSEventModifierFlagOption", 19),
        ("ctrl", "NSEventModifierFlagControl", 18),
        ("shift", "NSEventModifierFlagShift", 17),
    )
}

# macOS virtual keycodes (Carbon kVK_* values) for keys that are reported by name
# rather than by the character they produce. Unknown keycodes are rendered as
# "keycode_<n>" by the key serializer.
SPECIAL_KEYCODES: dict[int, str] = {
    0x24: "enter",  # kVK_Return
    0x30: "tab",  # kVK_Tab
    0x31: "space",  # kVK_Space
    0x33: "backspace",  # kVK_Delete
    0x35: "escape",  # kVK_Escape
    0x7B: "left",  # kVK_LeftArrow
    0x7C: "right",  # kVK_RightArrow
    0x7D: "down",  # kVK_DownArrow
    0x7E: "up",  # kVK_UpArrow
}


def _modifier_names(flags: int) -> list[str]:
    """Return the MODIFIER_FLAGS names whose bit is set in *flags*.

    Names come back in MODIFIER_FLAGS iteration order (cmd, alt, ctrl, shift).
    """
    return [label for label, bit in MODIFIER_FLAGS.items() if flags & bit]


def _frontmost_app_name() -> str | None:
    """Return the localized name of NSWorkspace's frontmost app, or None if unknown."""
    if NSWorkspace is None:
        return None
    try:
        frontmost = NSWorkspace.sharedWorkspace().frontmostApplication()
        if frontmost is None:
            return None
        return frontmost.localizedName() or None
    except Exception:
        # Best effort: any AppKit/PyObjC failure just means "unknown".
        return None


def _get_frontmost_window_context() -> tuple[str | None, str | None]:
    """Return ``(app_name, window_title)`` for the topmost on-screen window.

    The on-screen window list from Quartz is scanned in order. The first window
    that qualifies is chosen. A window qualifies when it:

    - is not owned by Window Server, Dock, or Control Center;
    - is on layer 0;
    - has alpha > 0;
    - is larger than 1x1.

    The app name is the window's owner; if the owner is empty, NSWorkspace's
    frontmost app is used instead. The title is ``kCGWindowName``, falling back
    to the app name.

    If no window qualifies or the scan fails, both values fall back to
    NSWorkspace's frontmost app name. Returns ``(None, None)`` when Quartz is
    unavailable.
    """
    if not Quartz:
        return None, None

    # Prefer Quartz window ordering because it reflects the topmost on-screen window,
    # which is more reliable than NSWorkspace frontmost app in some background contexts.
    # NSWorkspace is therefore only queried when the scan cannot supply an app name.
    try:
        options = (
            Quartz.kCGWindowListOptionOnScreenOnly | Quartz.kCGWindowListExcludeDesktopElements
        )
        windows = Quartz.CGWindowListCopyWindowInfo(options, Quartz.kCGNullWindowID) or []
        ignored_owners = {"Window Server", "Dock", "Control Center"}

        for window in windows:
            owner = window.get("kCGWindowOwnerName") or ""
            if owner in ignored_owners:
                continue
            # Only layer-0 (normal window level) windows count.
            if int(window.get("kCGWindowLayer", 1)) != 0:
                continue
            if float(window.get("kCGWindowAlpha", 1.0)) <= 0:
                continue
            bounds = window.get("kCGWindowBounds", {})
            if int(bounds.get("Width", 0)) <= 1 or int(bounds.get("Height", 0)) <= 1:
                continue

            app_name = owner or _frontmost_app_name()
            # A missing/empty title falls back to the app name, matching the
            # no-window fallback at the end of this function.
            return app_name, window.get("kCGWindowName") or app_name
    except Exception:
        # Best effort: a failed scan simply falls back to NSWorkspace below.
        pass

    app_name = _frontmost_app_name()
    return app_name, app_name


class EventTapWatcher:
    def __init__(self, callback: Callable, raw_key_logging: bool = False):
        self.callback = callback
        self.raw_key_logging = raw_key_logging
        self._thread: threading.Thread | None = None
        self._context_thread: threading.Thread | None = None
        self._run_loop = None
        self._tap = None
        self._ready = threading.Event()
        self._stop = threading.Event()
        self._context_lock = threading.Lock()
        self._context_cache: dict[str, Any] = {"at": 0.0, "app": None, "window": None}
        self._last_context_emitted: tuple[str | None, str | None] | None = None
        self._tap_ready = False

    def start(self) -> "EventTapWatcher":
        self._thread = threading.Thread(target=self._run_loop_target, daemon=True)
        self._thread.start()
        self._ready.wait(timeout=1.0)
        self._context_thread = threading.Thread(target=self._context_poll_loop, daemon=True)
        self._context_thread.start()
        return self

    def _get_context(self) -> tuple[str | None, str | None]:
        now = time.monotonic()
        with self._context_lock:
            if now - float(self._context_cache["at"]) < 0.08:
                return self._context_cache["app"], self._context_cache["window"]

        app_name, window_title = _get_frontmost_window_context()
        with self._context_lock:
            self._context_cache = {"at": now, "app": app_name, "window": window_title}
        return app_name, window_title

    def _emit(self, payload: dict[str, Any]) -> None:
        app_name, window_title = self._get_context()
        payload.setdefault("timestamp_ms", int(time.time() * 1000))
        payload["app"] = app_name
        payload["window"] = window_title
        self.callback(payload)

    def _emit_with_context(
        self, payload: dict[str, Any], app_name: str | None, window_title: str | None
    ) -> None:
        payload.setdefault("timestamp_ms", int(time.time() * 1000))
        payload["app"] = app_name
        payload["window"] = window_title
        self.callback(payload)

    def _context_poll_loop(self) -> None:
        while not self._stop.is_set():
            app_name, window_title = self._get_context()
            context = (app_name, window_title)
            if context != self._last_context_emitted:
                self._last_context_emitted = context
                self._emit_with_context({"type": "app_focus"}, app_name, window_title)
            time.sleep(0.12)

    def _serialize_key_event(self, event) -> dict[str, Any] | None:
        ns_event = NSEvent.eventWithCGEvent_(event) if NSEvent else None
        if not ns_event:
            return None

        flags = int(ns_event.modifierFlags())
        modifiers = _modifier_names(flags)
        keycode = int(Quartz.CGEventGetIntegerValueField(event, Quartz.kCGKeyboardEventKeycode))

        chars = (ns_event.charactersIgnoringModifiers() or ns_event.characters() or "").strip()
        if len(chars) == 1 and chars.isprintable():
            key_value = chars if self.raw_key_logging else "<redacted>"
            return {
                "type": "shortcut" if modifiers else "keypress",
                "key_type": "alphanumeric",
                "key": key_value,
                "modifiers": modifiers,
                "redacted": not self.raw_key_logging,
                "keycode": keycode,
            }

        return {
            "type": "keypress",
            "key_type": "special",
            "key": SPECIAL_KEYCODES.get(keycode, f"keycode_{keycode}"),
            "modifiers": modifiers,
            "redacted": False,
            "keycode": keycode,
        }

    def _handle_tap_event(self, _proxy, event_type, event, _refcon):
        if event_type in (
            Quartz.kCGEventTapDisabledByTimeout,
            Quartz.kCGEventTapDisabledByUserInput,
        ):
            if self._tap is not None:
                Quartz.CGEventTapEnable(self._tap, True)
            return event

        if event_type in (
            Quartz.kCGEventLeftMouseDown,
            Quartz.kCGEventRightMouseDown,
            Quartz.kCGEventOtherMouseDown,
        ):
            loc = Quartz.CGEventGetLocation(event)
            button = {
                Quartz.kCGEventLeftMouseDown: "left",
                Quartz.kCGEventRightMouseDown: "right",
                Quartz.kCGEventOtherMouseDown: "other",
            }.get(event_type, "other")
            self._emit(
                {
                    "type": "click",
                    "x": int(loc.x),
                    "y": int(loc.y),
                    "button": button,
                }
            )
            return event

        if event_type == Quartz.kCGEventScrollWheel:
            dy = int(
                Quartz.CGEventGetIntegerValueField(event, Quartz.kCGScrollWheelEventDeltaAxis1)
            )
            dx = int(
                Quartz.CGEventGetIntegerValueField(event, Quartz.kCGScrollWheelEventDeltaAxis2)
            )
            if dx == 0 and dy == 0:
                return event
            loc = Quartz.CGEventGetLocation(event)
            self._emit(
                {
                    "type": "mouse_scroll",
                    "x": int(loc.x),
                    "y": int(loc.y),
                    "dx": dx,
                    "dy": dy,
                    "direction": "down" if dy < 0 else "up",
                }
            )
            return event

        if event_type == Quartz.kCGEventKeyDown:
            payload = self._serialize_key_event(event)
            if payload:
                self._emit(payload)

        return event

    def _run_loop_target(self) -> None:
        mask = (
            Quartz.CGEventMaskBit(Quartz.kCGEventLeftMouseDown)
            | Quartz.CGEventMaskBit(Quartz.kCGEventRightMouseDown)
            | Quartz.CGEventMaskBit(Quartz.kCGEventOtherMouseDown)
            | Quartz.CGEventMaskBit(Quartz.kCGEventScrollWheel)
            | Quartz.CGEventMaskBit(Quartz.kCGEventKeyDown)
        )

        self._tap = Quartz.CGEventTapCreate(
            Quartz.kCGSessionEventTap,
            Quartz.kCGHeadInsertEventTap,
            Quartz.kCGEventTapOptionListenOnly,
            mask,
            self._handle_tap_event,
            None,
        )
        self._ready.set()
        if not self._tap:
            print("Failed to create event tap")
            return
        self._tap_ready = True

        source_factory = getattr(Quartz, "CFMachPortCreateRunLoopSource", None)
        if source_factory is None:
            print(
                "CFMachPortCreateRunLoopSource unavailable; cannot attach event tap run loop source"
            )
            return
        run_loop_source = source_factory(None, self._tap, 0)
        self._run_loop = Quartz.CFRunLoopGetCurrent()
        Quartz.CFRunLoopAddSource(self._run_loop, run_loop_source, Quartz.kCFRunLoopCommonModes)
        Quartz.CGEventTapEnable(self._tap, True)
        Quartz.CFRunLoopRun()

    def stop(self) -> None:
        if not Quartz:
            return
        self._stop.set()
        if self._tap is not None:
            try:
                Quartz.CGEventTapEnable(self._tap, False)
            except Exception:
                pass
        if self._run_loop is not None:
            try:
                Quartz.CFRunLoopStop(self._run_loop)
            except Exception:
                pass
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=1.0)
        if self._context_thread and self._context_thread.is_alive():
            self._context_thread.join(timeout=1.0)

    def is_active(self) -> bool:
        return bool(self._tap_ready)


def start_event_capture(callback: Callable):
    """Start a redacting EventTapWatcher that reports events to *callback*.

    Returns the started watcher. If Quartz could not be imported, prints a
    notice and returns None.
    """
    if Quartz:
        return EventTapWatcher(callback=callback, raw_key_logging=False).start()
    print("Quartz not available")
    return None
