from __future__ import annotations

import asyncio
import ctypes
import io
import logging
import multiprocessing as mp
import socket
import time
import uuid
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import BaseContext
from typing import ClassVar

import psutil
import pytest

from livekit.agents import JobContext, JobProcess, ipc, job, utils
from livekit.agents.ipc.log_queue import LogQueueHandler, LogQueueListener
from livekit.agents.utils.aio import duplex_unix
from livekit.protocol import agent

# Applied to every test in this module: unit tests that are written to be safe to
# run concurrently with each other (see the notes on avoiding process-wide monkeypatches).
pytestmark = [pytest.mark.unit, pytest.mark.concurrent]


@dataclass
class EmptyMessage:
    """Field-less test IPC message (no write/read); registered by MSG_ID in IPC_MESSAGES."""

    MSG_ID: ClassVar[int] = 0


@dataclass
class SomeDataMessage:
    """Test IPC message carrying one value of each primitive the channel helpers encode.

    ``write`` serializes the fields in declaration order and ``read`` decodes
    them back in that same order.
    """

    MSG_ID: ClassVar[int] = 1
    string: str = ""
    number: int = 0
    double: float = 0.0
    data: bytes = b""

    def write(self, b: io.BytesIO) -> None:
        """Append every field to ``b`` in declaration order."""
        chan = ipc.channel
        encoders = (
            (chan.write_string, self.string),
            (chan.write_int, self.number),
            (chan.write_double, self.double),
            (chan.write_bytes, self.data),
        )
        for encode, value in encoders:
            encode(b, value)

    def read(self, b: io.BytesIO) -> None:
        """Populate the fields from ``b``, consuming them in the order ``write`` emits."""
        chan = ipc.channel
        self.string = chan.read_string(b)
        self.number = chan.read_int(b)
        self.double = chan.read_double(b)
        self.data = chan.read_bytes(b)


# MSG_ID -> message class registry handed to the channel recv helpers.
IPC_MESSAGES = {msg_cls.MSG_ID: msg_cls for msg_cls in (EmptyMessage, SomeDataMessage)}


def _echo_main(mp_cch):
    """Child-process entry point: send every received IPC message straight back.

    Runs its own event loop and returns once the duplex reports DuplexClosed.
    """

    async def _serve() -> None:
        duplex = await utils.aio.duplex_unix._AsyncDuplex.open(mp_cch)
        try:
            while True:
                received = await ipc.channel.arecv_message(duplex, IPC_MESSAGES)
                await ipc.channel.asend_message(duplex, received)
        except utils.aio.duplex_unix.DuplexClosed:
            print("_echo_main, duplex closed..")

    asyncio.run(_serve())


async def test_async_channel():
    """Round-trip messages through an echo child process using the async duplex.

    The child is reaped in ``finally`` so it is not leaked when an assertion
    fails. It gets a grace period to exit on its own before ``terminate()`` is
    used as a fallback.
    """
    mp_pch, mp_cch = socket.socketpair()
    pch = await utils.aio.duplex_unix._AsyncDuplex.open(mp_pch)
    proc = mp.get_context("spawn").Process(target=_echo_main, args=(mp_cch,))
    proc.start()
    # the child works on its own handle to this socket; release the parent's copy
    mp_cch.close()

    try:
        await ipc.channel.asend_message(pch, EmptyMessage())
        assert await ipc.channel.arecv_message(pch, IPC_MESSAGES) == EmptyMessage()

        await ipc.channel.asend_message(
            pch, SomeDataMessage(string="hello", number=42, double=3.14, data=b"world")
        )
        assert await ipc.channel.arecv_message(pch, IPC_MESSAGES) == SomeDataMessage(
            string="hello", number=42, double=3.14, data=b"world"
        )
    finally:
        await pch.aclose()
        loop = asyncio.get_running_loop()
        # _echo_main stops on DuplexClosed, which presumably follows pch.aclose().
        # Join off the event loop, and only force the child down if it outlives the grace period.
        await loop.run_in_executor(None, proc.join, 5.0)
        if proc.is_alive():
            proc.terminate()
            await loop.run_in_executor(None, proc.join)


def test_sync_channel():
    """Round-trip messages through an echo child process using the blocking duplex.

    The child is always reaped. Once the duplex is closed the child is joined,
    with ``terminate()`` as a fallback, so the test never leaks a process, even
    when an assertion fails.
    """
    mp_pch, mp_cch = socket.socketpair()
    pch = utils.aio.duplex_unix._Duplex.open(mp_pch)

    proc = mp.get_context("spawn").Process(target=_echo_main, args=(mp_cch,))
    proc.start()
    # the child works on its own handle to this socket; release the parent's copy
    mp_cch.close()

    try:
        ipc.channel.send_message(pch, EmptyMessage())
        assert ipc.channel.recv_message(pch, IPC_MESSAGES) == EmptyMessage()

        ipc.channel.send_message(
            pch, SomeDataMessage(string="hello", number=42, double=3.14, data=b"world")
        )
        assert ipc.channel.recv_message(pch, IPC_MESSAGES) == SomeDataMessage(
            string="hello", number=42, double=3.14, data=b"world"
        )
    finally:
        pch.close()
        # _echo_main stops on DuplexClosed, which presumably follows pch.close();
        # only force the child down if it is still running after the grace period
        proc.join(timeout=5.0)
        if proc.is_alive():
            proc.terminate()
            proc.join()


def _generate_fake_job() -> job.RunningJobInfo:
    """Build a RunningJobInfo for a fake room job with a unique id and placeholder fields."""
    job_id = f"fake_job_{uuid.uuid4().hex}"
    room_job = agent.Job(id=job_id, type=agent.JobType.JT_ROOM)
    return job.RunningJobInfo(
        job=room_job,
        url="fake_url",
        token="fake_token",
        accept_arguments=job.JobAcceptArguments(name="", identity="", metadata=""),
        worker_id="fake_id",
        fake_job=True,
    )


@dataclass
class _StartArgs:
    """Shared state attached to each executor as ``user_arguments``.

    The counters and the condition are created from a multiprocessing context
    (see ``_new_start_args``), so increments made by the child-side hooks can be
    read by the parent test.
    """

    # NOTE(review): ``mp.Value`` / ``mp.Condition`` are factory functions, not types.
    # These annotations are plain strings at runtime (``from __future__ import
    # annotations``), so they only document intent. Consider
    # ``multiprocessing.sharedctypes.Synchronized`` / ``multiprocessing.synchronize.Condition``.
    # times the process-initialize hook has run
    initialize_counter: mp.Value
    # times a job entrypoint has started
    entrypoint_counter: mp.Value
    # times a job shutdown callback has run
    shutdown_counter: mp.Value
    # seconds each hook sleeps to simulate slow work
    initialize_simulate_work_time: float
    entrypoint_simulate_work_time: float
    shutdown_simulate_work_time: float
    # notified by the hooks after they make progress
    update_ev: mp.Condition


def _new_start_args(mp_ctx: BaseContext) -> _StartArgs:
    """Create a fresh _StartArgs whose shared objects all come from ``mp_ctx``.

    Every simulated work time starts at 0.0; individual tests override them.
    """

    def _shared_counter():
        return mp_ctx.Value(ctypes.c_uint)

    return _StartArgs(
        initialize_counter=_shared_counter(),
        entrypoint_counter=_shared_counter(),
        shutdown_counter=_shared_counter(),
        initialize_simulate_work_time=0.0,
        entrypoint_simulate_work_time=0.0,
        shutdown_simulate_work_time=0.0,
        update_ev=mp_ctx.Condition(),
    )


def _initialize_proc(proc: JobProcess) -> None:
    start_args: _StartArgs = proc.user_arguments

    # incrementing isn't atomic (the lock should be reentrant by default)
    with start_args.initialize_counter.get_lock():
        start_args.initialize_counter.value += 1

    time.sleep(start_args.initialize_simulate_work_time)

    with start_args.update_ev:
        start_args.update_ev.notify()


def _failing_initialize_proc(proc: JobProcess) -> None:
    # Runs in the spawned child; raising here makes every spawn's initialize() fail. Used instead
    # of monkeypatching ProcJobExecutor process-wide, so the test is safe to run concurrently.
    raise RuntimeError("simulated init failure")


async def _job_entrypoint(job_ctx: JobContext) -> None:
    start_args: _StartArgs = job_ctx.proc.user_arguments

    async def _job_shutdown() -> None:
        with start_args.shutdown_counter.get_lock():
            start_args.shutdown_counter.value += 1

        await asyncio.sleep(start_args.shutdown_simulate_work_time)

        with start_args.update_ev:
            start_args.update_ev.notify()

    job_ctx.add_shutdown_callback(_job_shutdown)

    with start_args.entrypoint_counter.get_lock():
        start_args.entrypoint_counter.value += 1

    await asyncio.sleep(start_args.entrypoint_simulate_work_time)

    job_ctx.shutdown(
        "calling shutdown inside the test to avoid a warning when neither shutdown nor connect is called."  # noqa: E501
    )

    with start_args.update_ev:
        start_args.update_ev.notify()


async def _job_entrypoint_session_aclose_hangs(job_ctx: JobContext) -> None:
    """Plant a fake AgentSession whose aclose() hangs forever; the
    _SESSION_ACLOSE_TIMEOUT guardrail must fire so shutdown callbacks still run."""
    from livekit.agents.ipc import job_proc_lazy_main

    # Shrink the hardcoded guardrail so the test doesn't wait the full 60s.
    # NOTE(review): this rebinds a module-level constant; presumably it only affects
    # the job process this entrypoint runs in, not the test process. Confirm.
    job_proc_lazy_main._SESSION_ACLOSE_TIMEOUT = 0.5

    start_args: _StartArgs = job_ctx.proc.user_arguments

    # Counted callback: the test asserts it ran exactly once despite the hang.
    async def _job_shutdown() -> None:
        with start_args.shutdown_counter.get_lock():
            start_args.shutdown_counter.value += 1

    job_ctx.add_shutdown_callback(_job_shutdown)

    # Minimal stand-in exposing only aclose(), which waits on an Event nobody sets.
    class _HangingSession:
        async def aclose(self) -> None:
            await asyncio.Event().wait()  # never resolves

    job_ctx._primary_agent_session = _HangingSession()  # type: ignore[assignment]

    # _on_session_end touches session.history and make_session_report, which
    # would crash on our minimal fake. Bypass it so the assertion is purely
    # about the aclose-timeout → shutdown-callback path.
    async def _noop_session_end() -> None:
        return None

    job_ctx._on_session_end = _noop_session_end  # type: ignore[method-assign]

    # Signals the parent (via _poll_until) that the job is running before aclose().
    with start_args.entrypoint_counter.get_lock():
        start_args.entrypoint_counter.value += 1

    job_ctx.shutdown("trigger hang in session.aclose()")


async def _job_entrypoint_raises_after_shutdown(job_ctx: JobContext) -> None:
    """Reproduces the room-disconnect race: _shutdown_fut is set before the
    entrypoint task unwinds, then the entrypoint raises while _run_job_task
    is awaiting it. The shutdown callback must still run."""
    start_args: _StartArgs = job_ctx.proc.user_arguments

    async def _job_shutdown() -> None:
        with start_args.shutdown_counter.get_lock():
            start_args.shutdown_counter.value += 1

    job_ctx.add_shutdown_callback(_job_shutdown)

    with start_args.entrypoint_counter.get_lock():
        start_args.entrypoint_counter.value += 1

    # set _shutdown_fut first (mimics the room "disconnected" handler), then
    # raise from a pending await (mimics wait_for_participant's RuntimeError).
    fut: asyncio.Future[None] = asyncio.Future()

    async def _trigger() -> None:
        job_ctx.shutdown("simulated room disconnect")
        await asyncio.sleep(0)
        fut.set_exception(RuntimeError("room disconnected while waiting for participant"))

    asyncio.create_task(_trigger())
    await fut


async def _poll_until(
    condition_fn: Callable[[], bool], *, timeout: float = 10.0, poll_interval: float = 0.05
) -> None:
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if condition_fn():
            return
        await asyncio.sleep(poll_interval)
    raise TimeoutError(f"Timed out after {timeout}s waiting for condition")


async def _wait_for_elements(q: asyncio.Queue, num_elements: int) -> None:
    for _ in range(num_elements):
        await q.get()


async def test_proc_pool():
    """End-to-end ProcPool lifecycle: warm idle processes, launch jobs, close gracefully.

    Checks that each warmed process runs the initialize hook, that each launched
    job runs its entrypoint and shutdown callback by the time ``pool.aclose()``
    returns, and that every child exits with code 0.
    """
    mp_ctx = mp.get_context("spawn")
    loop = asyncio.get_running_loop()
    num_idle_processes = 3
    pool = ipc.proc_pool.ProcPool(
        initialize_process_fnc=_initialize_proc,
        job_entrypoint_fnc=_job_entrypoint,
        session_end_fnc=None,
        simulation_end_fnc=None,
        num_idle_processes=num_idle_processes,
        job_executor_type=job.JobExecutorType.PROCESS,
        initialize_timeout=20.0,
        close_timeout=20.0,
        session_end_timeout=300.0,
        inference_executor=None,
        memory_warn_mb=0,
        memory_limit_mb=0,
        http_proxy=None,
        mp_ctx=mp_ctx,
        loop=loop,
    )

    start_args = _new_start_args(mp_ctx)
    # one queue per pool event: each handler pushes a token so the test can await counts
    created_q = asyncio.Queue()
    start_q = asyncio.Queue()
    ready_q = asyncio.Queue()
    close_q = asyncio.Queue()

    # collected from the event handlers and checked after the pool is closed
    pids = []
    exitcodes = []

    @pool.on("process_created")
    def _process_created(proc: ipc.job_proc_executor.ProcJobExecutor):
        created_q.put_nowait(None)
        # attach the shared counters so the child's hooks can update them
        proc.user_arguments = start_args

    @pool.on("process_started")
    def _process_started(proc: ipc.job_proc_executor.ProcJobExecutor):
        start_q.put_nowait(None)
        pids.append(proc.pid)

    @pool.on("process_ready")
    def _process_ready(proc: ipc.job_proc_executor.ProcJobExecutor):
        ready_q.put_nowait(None)

    @pool.on("process_closed")
    def _process_closed(proc: ipc.job_proc_executor.ProcJobExecutor):
        close_q.put_nowait(None)
        exitcodes.append(proc.exitcode)

    await pool.start()

    # wait for the initial warm-up batch to be created, started and ready
    await _wait_for_elements(created_q, num_idle_processes)
    await _wait_for_elements(start_q, num_idle_processes)
    await _wait_for_elements(ready_q, num_idle_processes)

    assert start_args.initialize_counter.value == num_idle_processes

    jobs_to_start = 2

    for _ in range(jobs_to_start):
        await pool.launch_job(_generate_fake_job())

    # NOTE(review): presumably each launched job takes a warmed process and the pool
    # spawns a replacement, hence one more created/started/ready event per job.
    await _wait_for_elements(created_q, jobs_to_start)
    await _wait_for_elements(start_q, jobs_to_start)
    await _wait_for_elements(ready_q, jobs_to_start)

    # NOTE(review): the pool is not closed if anything above fails; a try/finally
    # around the body would avoid leaking child processes on failure.
    await pool.aclose()

    # by the time aclose() returns, every job ran its entrypoint and shutdown callback
    assert start_args.entrypoint_counter.value == jobs_to_start
    assert start_args.shutdown_counter.value == jobs_to_start

    await _wait_for_elements(close_q, num_idle_processes + jobs_to_start)

    # the way we check that a process doesn't exist anymore isn't technically reliable (pid recycle could happen)  # noqa: E501
    for pid in pids:
        assert not psutil.pid_exists(pid)

    for exitcode in exitcodes:
        # this test expects graceful shutdown, kill is tested on another test
        assert exitcode == 0, f"process did not exit cleanly: {exitcode}"


async def test_slow_initialization():
    """Processes whose initialize hook outlasts ``initialize_timeout`` are killed.

    The hook sleeps 2.0s against a 1.0s timeout, so every spawn in the first
    batch, and in the batch that replaces it, is closed with a non-zero exit code.
    """
    mp_ctx = mp.get_context("spawn")
    loop = asyncio.get_running_loop()
    num_idle_processes = 2
    pool = ipc.proc_pool.ProcPool(
        job_executor_type=job.JobExecutorType.PROCESS,
        initialize_process_fnc=_initialize_proc,
        job_entrypoint_fnc=_job_entrypoint,
        session_end_fnc=None,
        simulation_end_fnc=None,
        num_idle_processes=num_idle_processes,
        initialize_timeout=1.0,
        close_timeout=20.0,
        session_end_timeout=300.0,
        inference_executor=None,
        memory_warn_mb=0,
        memory_limit_mb=0,
        http_proxy=None,
        mp_ctx=mp_ctx,
        loop=loop,
    )

    start_args = _new_start_args(mp_ctx)
    # longer than initialize_timeout above, so initialization always times out
    start_args.initialize_simulate_work_time = 2.0
    start_q = asyncio.Queue()
    close_q = asyncio.Queue()

    pids = []
    exitcodes = []

    @pool.on("process_created")
    def _process_created(proc: ipc.job_proc_executor.ProcJobExecutor):
        proc.user_arguments = start_args
        start_q.put_nowait(None)

    @pool.on("process_closed")
    def _process_closed(proc: ipc.job_proc_executor.ProcJobExecutor):
        close_q.put_nowait(None)
        # pid may be unset for a process that never got far enough to have one
        if proc.pid is not None:
            pids.append(proc.pid)
        exitcodes.append(proc.exitcode)

    await pool.start()

    # first batch: created, then closed after timing out in initialize
    await _wait_for_elements(start_q, num_idle_processes)
    await _wait_for_elements(close_q, num_idle_processes)

    # retry batch should also timeout and be killed
    await _wait_for_elements(start_q, num_idle_processes)
    await _wait_for_elements(close_q, num_idle_processes)

    await pool.aclose()

    for pid in pids:
        assert not psutil.pid_exists(pid)

    for exitcode in exitcodes:
        assert exitcode != 0, "process should have been killed"


async def test_proc_pool_launch_job_raises_when_all_spawns_fail():
    """launch_job must raise, rather than hang on an empty warmed-process queue,
    when every spawn fails to initialize. Reproduces #5868.

    The failure comes from a real executor whose init function raises, not from
    monkeypatching ProcJobExecutor process-wide. That keeps the test safe to run
    concurrently with its peers.
    """
    ctx = mp.get_context("spawn")
    pool = ipc.proc_pool.ProcPool(
        initialize_process_fnc=_failing_initialize_proc,
        job_entrypoint_fnc=_job_entrypoint,
        session_end_fnc=None,
        simulation_end_fnc=None,
        inference_executor=None,
        job_executor_type=job.JobExecutorType.PROCESS,
        num_idle_processes=0,
        # generous so initialize() fails via the init fn raising, not a spawn-racing timeout
        initialize_timeout=10.0,
        close_timeout=20.0,
        session_end_timeout=300.0,
        memory_warn_mb=0,
        memory_limit_mb=0,
        http_proxy=None,
        mp_ctx=ctx,
        loop=asyncio.get_running_loop(),
    )
    await pool.start()

    try:
        fake_job = _generate_fake_job()
        with pytest.raises(RuntimeError, match="no process became available"):
            await pool.launch_job(fake_job)

        # nothing may be left queued or tracked after the failed launch
        assert pool._jobs_waiting_for_process == 0
        assert len(pool.processes) == 0
    finally:
        await pool.aclose()


def _create_proc(
    *,
    close_timeout: float,
    mp_ctx: BaseContext,
    initialize_timeout: float = 20.0,
    job_entrypoint_fnc: Callable[[JobContext], object] = _job_entrypoint,
) -> tuple[ipc.job_proc_executor.ProcJobExecutor, _StartArgs]:
    """Build an unstarted ProcJobExecutor wired to a fresh set of shared _StartArgs.

    Must be called while an event loop is running; the executor is bound to it.
    Ping settings are fixed, memory thresholds are passed as 0, and no inference
    executor or HTTP proxy is configured.

    Returns:
        The executor, plus the _StartArgs already attached as its ``user_arguments``.
    """
    shared = _new_start_args(mp_ctx)
    executor = ipc.job_proc_executor.ProcJobExecutor(
        loop=asyncio.get_running_loop(),
        mp_ctx=mp_ctx,
        initialize_process_fnc=_initialize_proc,
        job_entrypoint_fnc=job_entrypoint_fnc,
        session_end_fnc=None,
        simulation_end_fnc=None,
        inference_executor=None,
        http_proxy=None,
        initialize_timeout=initialize_timeout,
        close_timeout=close_timeout,
        session_end_timeout=300.0,
        ping_interval=2.5,
        ping_timeout=10.0,
        high_ping_threshold=1.0,
        memory_warn_mb=0,
        memory_limit_mb=0,
    )
    executor.user_arguments = shared
    return executor, shared


async def test_aclose_after_cancelled_start():
    """When start() is cancelled mid-flight, the shielded _start() keeps running;
    aclose() must still tear down the supervise task and the child process."""
    mp_ctx = mp.get_context("spawn")
    proc, _ = _create_proc(close_timeout=10.0, mp_ctx=mp_ctx)

    start_task = asyncio.create_task(proc.start())
    await asyncio.sleep(0)  # let start() enter the shielded _start()
    # cancelling the outer task surfaces CancelledError here, while the shielded
    # inner work is expected to keep going
    start_task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await start_task

    # mimics ProcPool._proc_spawn_task's cleanup after swallowing the cancellation
    await proc.aclose()

    # start() still finished under the shield, and aclose() cleaned up after it
    assert proc._supervise_atask is not None, "start() should have completed under the shield"
    assert proc._supervise_atask.done()
    assert proc.pid is not None
    assert not psutil.pid_exists(proc.pid)


async def test_shutdown_no_job():
    """A process that never received a job exits cleanly without running shutdown callbacks."""
    executor, args = _create_proc(close_timeout=10.0, mp_ctx=mp.get_context("spawn"))

    await executor.start()
    await executor.initialize()
    await executor.aclose()

    assert executor.exitcode == 0
    assert not executor.killed
    assert args.shutdown_counter.value == 0, "shutdown_cb isn't called when there is no job"


async def test_job_slow_shutdown():
    """A shutdown callback that outlives close_timeout gets the process killed."""
    executor, args = _create_proc(close_timeout=0.3, mp_ctx=mp.get_context("spawn"))
    # far longer than close_timeout, so aclose() cannot wait it out
    args.shutdown_simulate_work_time = 10.0

    await executor.start()
    await executor.initialize()

    await executor.launch_job(_generate_fake_job())
    await _poll_until(lambda: args.entrypoint_counter.value >= 1)
    await executor.aclose()

    assert executor.exitcode != 0, "process should have been killed"
    assert executor.killed


async def test_shutdown_callback_runs_when_session_aclose_hangs():
    """Regression: a session whose aclose() never returns must not block shutdown.

    The _SESSION_ACLOSE_TIMEOUT guardrail has to fire so user-registered shutdown
    callbacks still run. The entrypoint lowers that hardcoded constant to 0.5s to
    keep the test fast.
    """
    executor, args = _create_proc(
        job_entrypoint_fnc=_job_entrypoint_session_aclose_hangs,
        close_timeout=20.0,
        mp_ctx=mp.get_context("spawn"),
    )
    await executor.start()
    await executor.initialize()

    await executor.launch_job(_generate_fake_job())
    await _poll_until(lambda: args.entrypoint_counter.value >= 1)
    await executor.aclose()

    assert executor.exitcode == 0, "process should have exited cleanly"
    assert not executor.killed
    assert args.shutdown_counter.value == 1, (
        "shutdown callback must run even when session.aclose() hangs"
    )


async def test_shutdown_callback_runs_when_entrypoint_raises():
    """Regression: shutdown callbacks still run when the entrypoint raises after
    _shutdown_fut was already set, as happens on a room disconnect in the middle
    of wait_for_participant."""
    executor, args = _create_proc(
        job_entrypoint_fnc=_job_entrypoint_raises_after_shutdown,
        close_timeout=10.0,
        mp_ctx=mp.get_context("spawn"),
    )
    await executor.start()
    await executor.initialize()

    await executor.launch_job(_generate_fake_job())
    await _poll_until(lambda: args.entrypoint_counter.value >= 1)
    await executor.aclose()

    assert executor.exitcode == 0, "process should have exited cleanly"
    assert not executor.killed
    assert args.shutdown_counter.value == 1, (
        "shutdown callback must run even when entrypoint raises"
    )


async def test_job_graceful_shutdown():
    """A shutdown callback that finishes within close_timeout lets the process exit with 0."""
    executor, args = _create_proc(close_timeout=10.0, mp_ctx=mp.get_context("spawn"))
    args.shutdown_simulate_work_time = 0.3  # comfortably inside close_timeout
    await executor.start()
    await executor.initialize()

    await executor.launch_job(_generate_fake_job())
    await _poll_until(lambda: args.entrypoint_counter.value >= 1)
    await executor.aclose()

    assert executor.exitcode == 0, "process should have exited cleanly"
    assert not executor.killed
    assert args.shutdown_counter.value == 1


def test_log_queue_drains_before_stop():
    """All log records must be received by the listener even when stop() is
    called right after the sender closes its end.  This reproduces a race where
    LogQueueListener.stop() used to close the socket *before* joining the
    monitor thread, dropping buffered records."""
    NUM_LOGS = 200
    # messages observed by the listener, in arrival order
    received: list[str] = []

    # both ends live in this process; the socketpair stands in for the parent/child link
    parent_sock, child_sock = socket.socketpair()
    parent_dup = duplex_unix._Duplex.open(parent_sock)
    child_dup = duplex_unix._Duplex.open(child_sock)

    # -- parent (listener) side --
    class _CapturingListener(LogQueueListener):
        def handle(self, record: logging.LogRecord) -> None:
            received.append(record.getMessage())
            # slow down processing so the buffer is not fully drained
            # before stop() is called
            time.sleep(0.001)

    listener = _CapturingListener(parent_dup, lambda r: None)
    listener.start()

    # -- child (handler) side --
    # a dedicated, non-propagating logger so only these records reach the handler
    handler = LogQueueHandler(child_dup)
    test_logger = logging.getLogger("test_log_queue_drain")
    test_logger.addHandler(handler)
    test_logger.setLevel(logging.DEBUG)
    test_logger.propagate = False

    for i in range(NUM_LOGS):
        test_logger.info("msg %d", i)

    # simulate the child process shutting down: close handler then its thread
    handler.close()
    handler.thread.join()
    # child duplex is now closed by _forward_logs

    # simulate supervised_proc._sync_run: proc.join() returned, now stop listener
    listener.stop()

    # detach before asserting so the named logger isn't left pointing at a closed handler
    test_logger.removeHandler(handler)

    assert len(received) == NUM_LOGS, (
        f"Expected {NUM_LOGS} records, got {len(received)}. "
        f"Lost {NUM_LOGS - len(received)} tail log records."
    )
