import os
from typing import Any

import pytest

from livekit.agents import inference
from livekit.agents.llm import AgentHandoff, ChatContext, FunctionCall, FunctionCallOutput, utils
from livekit.agents.types import (
    DEFAULT_API_CONNECT_OPTIONS,
    NOT_GIVEN,
    APIConnectOptions,
    NotGivenOr,
)

from .fake_llm import FakeLLM, FakeLLMResponse

# Module-level pytest markers: every test collected from this file is tagged
# with both `unit` and `concurrent`.
pytestmark = [pytest.mark.unit, pytest.mark.concurrent]


def ai_function1(a: int, b: str = "default") -> None:
    """
    This is a test function
    Args:
        a: First argument
        b: Second argument
    """
    # Intentionally empty: test_args_model only inspects this function's
    # signature and docstring, so the docstring text above acts as test data.
    pass


def skip_if_no_credentials():
    """Return a ``skipif`` marker that triggers when LiveKit credentials are not set.

    The environment is read when this function is called (i.e. when the
    decorator expression is evaluated), and the names of the unset variables
    are listed in the skip reason.
    """
    absent = []
    for name in ("LIVEKIT_API_KEY", "LIVEKIT_API_SECRET"):
        if not os.getenv(name):
            absent.append(name)

    reason = f"Missing environment variables: {', '.join(absent)}"
    return pytest.mark.skipif(len(absent) > 0, reason=reason)


def test_args_model():
    """Smoke test: parse the docstring of ``ai_function1`` and build its argument model.

    Nothing is asserted; the parsed description and the JSON schema are printed.
    """
    from docstring_parser import parse_from_object

    parsed = parse_from_object(ai_function1)
    print(parsed.description)

    args_model = utils.function_arguments_to_pydantic_model(ai_function1)
    schema = args_model.model_json_schema()
    print(schema)


def test_dict():
    """Smoke test: ``to_dict`` / ``from_dict`` run on a context with text, Instructions and an image.

    Nothing is asserted; the serialized dict and the item lists are printed.
    """
    from livekit import rtc
    from livekit.agents.beta import Instructions
    from livekit.agents.llm import ChatContext, ImageContent

    system_prompt = Instructions(
        "You are a helpful assistant in audio mode.",
        text="You are a helpful assistant in text mode.",
    )
    # 64x64 frame whose buffer holds width * height * 3 bytes
    side = 64
    frame = rtc.VideoFrame(side, side, rtc.VideoBufferType.RGB24, b"0" * side * side * 3)

    chat_ctx = ChatContext()
    conversation = (
        ("system", system_prompt),
        ("user", "Hello, world!"),
        ("assistant", "Hello, world!"),
        ("user", [ImageContent(image=frame)]),
    )
    for role, content in conversation:
        chat_ctx.add_message(role=role, content=content)

    print(chat_ctx.to_dict())
    print(chat_ctx.items)
    restored = ChatContext.from_dict(chat_ctx.to_dict())
    print(restored.items)


def test_chat_ctx_can_be_serialized_and_deserialized_with_defaults():
    """A context of items built with default field values survives a to_dict/from_dict round trip."""
    from livekit.agents.llm import AgentHandoff, ChatContext, ChatMessage

    handoff = AgentHandoff(new_agent_id="default_agent", old_agent_id=None)
    user_msg = ChatMessage(role="user", content=["Hello, world!"])
    assistant_msg = ChatMessage(role="assistant", content=["Hi there!"])

    original = ChatContext([handoff, user_msg, assistant_msg])
    round_tripped = ChatContext.from_dict(original.to_dict())

    assert original.is_equivalent(round_tripped)


@skip_if_no_credentials()
# Explicit asyncio marker, matching the other async tests in this module, so the
# coroutine is actually awaited even when pytest-asyncio runs in strict mode.
@pytest.mark.asyncio
async def test_summarize():
    """Integration smoke test: summarize a long support conversation with a real inference LLM.

    Skipped unless LIVEKIT_API_KEY and LIVEKIT_API_SECRET are set. The
    conversation mixes chat messages, tool calls with their outputs, and an
    agent handoff. Nothing is asserted; the resulting summary context is
    printed as JSON for manual inspection.
    """
    import json

    from livekit.agents import ChatContext

    chat_ctx = ChatContext()
    chat_ctx.add_message(
        role="system",
        content=(
            "You are SupportGPT, a customer service agent for Acme Audio. "
            "Gather identifying info first, then troubleshoot. "
            "Only promise replacements if the device is under warranty. "
            "Use the provided tools for order lookup, warranty checks, and RMA creation. "
            "If a return is required, hand off to ReturnsAgent for shipping label logistics."
        ),
    )

    # --- identification ---
    chat_ctx.add_message(
        role="user",
        content=(
            "Hi, I need help with an order I placed last week. The earbuds I got "
            "keep disconnecting and the left side sounds crackly."
        ),
    )
    chat_ctx.add_message(
        role="assistant",
        content=(
            "I can help with that! First, could you share your full name and the email "
            "you used at checkout, so I can locate your order?"
        ),
    )
    chat_ctx.add_message(
        role="user", content=("Sure—I'm Maya Chen, and I used maya.chen+shop@gmail.com.")
    )
    chat_ctx.add_message(
        role="assistant",
        content=("Thanks, Maya. Do you also have the order number and approximate purchase date?"),
    )
    chat_ctx.add_message(role="user", content=("Order #LK-4821936. I bought them on October 7."))

    # --- order lookup tool call ---
    chat_ctx.items.append(
        FunctionCall(
            name="lookup_order",
            call_id="call_lookup_order_1",
            arguments='{"order_number": "LK-4821936", "customer_email": "maya.chen+shop@gmail.com"}',
        )
    )
    chat_ctx.items.append(
        FunctionCallOutput(
            name="lookup_order",
            call_id="call_lookup_order_1",
            output=(
                '{"order_number":"LK-4821936","customer_name":"Maya Chen","'
                'items":[{"sku":"AC-EBD-PRO","name":"Acme Buds Pro","qty":1}],'
                '"purchase_date":"2025-10-07","status":"delivered","serial":"ACB-PRO-7F29D4"}'
            ),
            is_error=False,
        )
    )

    # --- troubleshooting ---
    chat_ctx.add_message(
        role="assistant",
        content=(
            "I found your order LK-4821936 for Acme Buds Pro, delivered October 8. "
            "To check warranty and next steps, which device are you pairing with and what OS version?"
        ),
    )
    chat_ctx.add_message(role="user", content="iPhone 14 Pro, iOS 18.0.1.")
    chat_ctx.add_message(
        role="assistant",
        content=(
            "Thanks. Have you tried any troubleshooting—resetting the buds, forgetting/re-pairing Bluetooth, "
            "or testing another device?"
        ),
    )
    chat_ctx.add_message(
        role="user",
        content=(
            "I tried forgetting and re-pairing twice. I also tested on my iPad and the left ear still crackles."
        ),
    )
    chat_ctx.add_message(
        role="assistant",
        content=("Understood. Any visible damage or signs of moisture? And when did it start?"),
    )
    chat_ctx.add_message(
        role="user",
        content=("No damage or moisture. It started the day after I received them—October 9."),
    )

    # --- warranty check tool call ---
    chat_ctx.items.append(
        FunctionCall(
            name="check_warranty",
            call_id="call_check_warranty_1",
            arguments='{"serial":"ACB-PRO-7F29D4","purchase_date":"2025-10-07"}',
        )
    )
    chat_ctx.items.append(
        FunctionCallOutput(
            name="check_warranty",
            call_id="call_check_warranty_1",
            output='{"eligible":true,"warranty_expires":"2026-10-07"}',
            is_error=False,
        )
    )

    chat_ctx.add_message(
        role="assistant",
        content=(
            "This appears to be a hardware defect and you’re under warranty until 2026-10-07. "
            "I can set up a free replacement. Could you confirm your shipping address and a contact number?"
        ),
    )
    chat_ctx.add_message(
        role="user",
        content=("Ship to 2150 Grove St, Apt 4B, Oakland, CA 94612. Phone is (510) 555-0136."),
    )

    # --- RMA creation tool call ---
    chat_ctx.items.append(
        FunctionCall(
            name="create_rma",
            call_id="call_create_rma_1",
            arguments=(
                '{"order_number":"LK-4821936","serial":"ACB-PRO-7F29D4","reason":"left bud crackling / disconnects",'
                '"customer":{"name":"Maya Chen","email":"maya.chen+shop@gmail.com","phone":"(510) 555-0136","'
                'address":"2150 Grove St, Apt 4B, Oakland, CA 94612"}}'
            ),
        )
    )
    chat_ctx.items.append(
        FunctionCallOutput(
            name="create_rma",
            call_id="call_create_rma_1",
            output='{"rma_id":"RMA-90721","replacement_eta_days":2}',
            is_error=False,
        )
    )

    # --- handoff to the returns agent, which issues the shipping label ---
    chat_ctx.items.append(AgentHandoff(old_agent_id="SupportGPT", new_agent_id="ReturnsAgent"))

    chat_ctx.items.append(
        FunctionCall(
            name="generate_return_label",
            call_id="call_label_1",
            arguments='{"rma_id":"RMA-90721","email":"maya.chen+shop@gmail.com"}',
        )
    )
    chat_ctx.items.append(
        FunctionCallOutput(
            name="generate_return_label",
            call_id="call_label_1",
            output='{"label_url":"https://example.invalid/label/RMA-90721","due_in_days":14}',
            is_error=False,
        )
    )

    chat_ctx.add_message(
        role="assistant",
        content=(
            "All set! I’ve created RMA #RMA-90721 linked to order LK-4821936. "
            "You’ll receive the prepaid return label and instructions at maya.chen+shop@gmail.com. "
            "Please ship the defective pair within 14 days; your replacement will ship within 48 hours."
        ),
    )

    async with inference.LLM(model="openai/gpt-4.1-mini") as llm:
        summary = await chat_ctx._summarize(llm, keep_last_turns=1)
        print("\n=== Summary ===\n")
        print(json.dumps(summary.to_dict(), indent=2))


# --- summarize unit tests (no credentials required) ---


class _FixedSummaryLLM(FakeLLM):
    """FakeLLM stub that answers every request with one preset summary string."""

    def __init__(self, summary: str) -> None:
        super().__init__()
        self._summary = summary

    def chat(
        self,
        *,
        chat_ctx: ChatContext,
        tools: Any = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: Any = NOT_GIVEN,
        extra_kwargs: Any = NOT_GIVEN,
    ):
        """Register the preset reply for the last item's text, then defer to ``FakeLLM.chat``.

        ``parallel_tool_calls``, ``tool_choice`` and ``extra_kwargs`` are accepted
        but not forwarded to the parent implementation.
        """
        prompt_text = chat_ctx.items[-1].text_content
        canned_reply = FakeLLMResponse(
            input=prompt_text,
            content=self._summary,
            ttft=0.0,
            duration=0.0,
        )
        # Keyed by the incoming text — presumably how FakeLLM looks up its reply; confirm in fake_llm
        self._fake_response_map[prompt_text] = canned_reply
        return super().chat(chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)


# Fixed summary text handed to _FixedSummaryLLM by the summarize unit tests below.
CANNED_SUMMARY = "User asked about earbuds. Agent resolved the issue."


def _build_conversation_ctx() -> ChatContext:
    """Return a ChatContext with a system prompt, user/assistant turns and one tool call pair."""
    from livekit.agents.llm import ChatContext

    ctx = ChatContext()

    opening_turns = (
        ("system", "You are a helpful assistant."),
        ("user", "Hi, my earbuds are broken."),
        ("assistant", "Can you share your order number?"),
        ("user", "Order #123."),
    )
    for role, text in opening_turns:
        ctx.add_message(role=role, content=text)

    # The tool call and its output sit between "Order #123." and the assistant's follow-up
    tool_items = (
        FunctionCall(name="lookup_order", call_id="c1", arguments='{"order": "123"}'),
        FunctionCallOutput(
            name="lookup_order", call_id="c1", output='{"status":"delivered"}', is_error=False
        ),
    )
    for tool_item in tool_items:
        ctx.items.append(tool_item)

    closing_turns = (
        ("assistant", "Found your order. Let me check warranty."),
        ("user", "Thanks."),
        ("assistant", "You are under warranty."),
    )
    for role, text in closing_turns:
        ctx.add_message(role=role, content=text)

    return ctx


@pytest.mark.asyncio
async def test_summarize_head_tail_split_basic():
    """With keep_last_turns=1 the final user/assistant pair stays verbatim next to one summary."""
    from livekit.agents.llm import ChatContext

    ctx = ChatContext()
    ctx.add_message(role="system", content="System prompt.")
    for turn in (1, 2, 3):
        ctx.add_message(role="user", content=f"msg{turn}")
        ctx.add_message(role="assistant", content=f"reply{turn}")

    result = await ctx._summarize(_FixedSummaryLLM(CANNED_SUMMARY), keep_last_turns=1)

    # Only the last turn remains as raw conversation messages
    kept = []
    for item in result.items:
        if item.type != "message" or item.role not in ("user", "assistant"):
            continue
        if item.extra.get("is_summary"):
            continue
        kept.append(item)
    assert len(kept) == 2
    assert [m.text_content for m in kept] == ["msg3", "reply3"]

    # Exactly one summary message, carrying the canned text
    summary_msgs = [
        item for item in result.items if item.type == "message" and item.extra.get("is_summary")
    ]
    assert len(summary_msgs) == 1
    assert CANNED_SUMMARY in summary_msgs[0].text_content

    # The system prompt is still present once
    system_count = sum(
        1 for item in result.items if item.type == "message" and item.role == "system"
    )
    assert system_count == 1


@pytest.mark.asyncio
async def test_summarize_head_tail_split_with_renderables():
    """Tool call items that fall inside the kept tail are preserved by summarization."""
    source_ctx = _build_conversation_ctx()
    result = await source_ctx._summarize(_FixedSummaryLLM(CANNED_SUMMARY), keep_last_turns=2)

    # keep_last_turns=2 means the backward walk counts four ChatMessages:
    #   "You are under warranty." (1), "Thanks." (2),
    #   "Found your order..." (3), "Order #123." (4) <- split point
    # The FunctionCall / FunctionCallOutput pair lies between "Order #123." and
    # "Found your order...", i.e. inside the tail, so it has to be kept.
    kept = []
    for item in result.items:
        if item.type != "message" or item.role not in ("user", "assistant"):
            continue
        if item.extra.get("is_summary"):
            continue
        kept.append(item)
    assert len(kept) == 4
    assert [m.text_content for m in kept] == [
        "Order #123.",
        "Found your order. Let me check warranty.",
        "Thanks.",
        "You are under warranty.",
    ]

    tool_types = ("function_call", "function_call_output")
    tool_count = sum(1 for item in result.items if item.type in tool_types)
    assert tool_count == 2

    summary_count = sum(
        1 for item in result.items if item.type == "message" and item.extra.get("is_summary")
    )
    assert summary_count == 1


@pytest.mark.asyncio
async def test_summarize_keep_last_turns_zero():
    """With keep_last_turns=0 only the system prompt and a single summary remain."""
    source_ctx = _build_conversation_ctx()
    result = await source_ctx._summarize(_FixedSummaryLLM(CANNED_SUMMARY), keep_last_turns=0)

    # No raw user/assistant message is left over
    leftover = []
    for item in result.items:
        if item.type != "message" or item.role not in ("user", "assistant"):
            continue
        if item.extra.get("is_summary"):
            continue
        leftover.append(item)
    assert len(leftover) == 0

    # Tool call items are gone as well
    tool_types = ("function_call", "function_call_output")
    tool_count = sum(1 for item in result.items if item.type in tool_types)
    assert tool_count == 0

    summary_count = sum(
        1 for item in result.items if item.type == "message" and item.extra.get("is_summary")
    )
    assert summary_count == 1

    system_count = sum(
        1 for item in result.items if item.type == "message" and item.role == "system"
    )
    assert system_count == 1


@pytest.mark.asyncio
async def test_summarize_preserves_structural_items():
    """The system message and an AgentHandoff located before the kept tail survive summarization."""
    from livekit.agents.llm import ChatContext

    ctx = ChatContext()
    for role, text in (
        ("system", "System prompt."),
        ("user", "Hello."),
        ("assistant", "Hi there."),
    ):
        ctx.add_message(role=role, content=text)
    ctx.items.append(AgentHandoff(old_agent_id="AgentA", new_agent_id="AgentB"))
    for role, text in (
        ("user", "Transfer me."),
        ("assistant", "Done."),
        ("user", "Thanks."),
        ("assistant", "Welcome."),
    ):
        ctx.add_message(role=role, content=text)

    result = await ctx._summarize(_FixedSummaryLLM(CANNED_SUMMARY), keep_last_turns=1)

    # the system message is kept
    system_count = sum(
        1 for item in result.items if item.type == "message" and item.role == "system"
    )
    assert system_count == 1

    # the agent handoff is kept with its ids unchanged
    kept_handoffs = [item for item in result.items if item.type == "agent_handoff"]
    assert len(kept_handoffs) == 1
    only_handoff = kept_handoffs[0]
    assert only_handoff.old_agent_id == "AgentA"
    assert only_handoff.new_agent_id == "AgentB"


@pytest.mark.asyncio
async def test_summarize_skips_when_not_enough_messages():
    """When the kept tail already covers every message, the items come back unchanged."""
    from livekit.agents.llm import ChatContext

    ctx = ChatContext()
    for role, text in (
        ("system", "System prompt."),
        ("user", "Hello."),
        ("assistant", "Hi there."),
    ):
        ctx.add_message(role=role, content=text)

    snapshot = list(ctx.items)

    result = await ctx._summarize(_FixedSummaryLLM(CANNED_SUMMARY), keep_last_turns=1)

    # the budget covers all messages, so there is nothing to summarize (early return)
    assert len(result.items) == len(snapshot)
    assert [item.id for item in result.items] == [item.id for item in snapshot]


# --- truncate tests ---


def _make_ctx(*roles: str):
    """Build a ChatContext containing one item per entry in ``roles``.

    ``"function_call"`` and ``"function_call_output"`` append the matching
    tool-call item; any other value is added as a message with that role and
    the content ``msg-<role>``.
    """
    from livekit.agents.llm import ChatContext

    ctx = ChatContext()
    for role in roles:
        if role not in ("function_call", "function_call_output"):
            ctx.add_message(role=role, content=f"msg-{role}")
            continue

        if role == "function_call":
            tool_item = FunctionCall(name="fn", call_id="c1", arguments="{}")
        else:
            tool_item = FunctionCallOutput(name="fn", call_id="c1", output="{}", is_error=False)
        ctx.items.append(tool_item)
    return ctx


def test_truncate_noop_when_under_limit():
    """truncate() leaves the items alone when there are fewer of them than max_items."""
    ctx = _make_ctx("system", "user", "assistant")
    ids_before = [entry.id for entry in ctx.items]

    ctx.truncate(max_items=5)

    ids_after = [entry.id for entry in ctx.items]
    assert ids_after == ids_before


def test_truncate_basic():
    """truncate(max_items=2) on four alternating messages leaves a user/assistant pair."""
    ctx = _make_ctx("user", "assistant", "user", "assistant")

    ctx.truncate(max_items=2)

    assert len(ctx.items) == 2
    remaining_roles = [ctx.items[0].role, ctx.items[1].role]
    assert remaining_roles == ["user", "assistant"]


def test_truncate_preserves_system_instruction():
    """A leading system message is kept at the front in addition to the last max_items items."""
    ctx = _make_ctx("system", "user", "assistant", "user", "assistant")

    ctx.truncate(max_items=2)

    # the system message is expected back at index 0
    first_item = ctx.items[0]
    assert first_item.role == "system"
    # system + the last 2 items
    assert len(ctx.items) == 3


def test_truncate_preserves_developer_instruction():
    """A leading developer message is kept at the front in addition to the last max_items items."""
    ctx = _make_ctx("developer", "user", "assistant", "user", "assistant")

    ctx.truncate(max_items=2)

    first_item = ctx.items[0]
    assert first_item.role == "developer"
    assert len(ctx.items) == 3


def test_truncate_no_duplication():
    """An instruction that already sits inside the kept tail is not inserted a second time."""
    ctx = _make_ctx("system", "user", "assistant")

    ctx.truncate(max_items=3)

    # the system message is among the last 3 items already, so it must appear only once
    system_count = sum(1 for entry in ctx.items if getattr(entry, "role", None) == "system")
    assert system_count == 1
    assert len(ctx.items) <= 3


def test_truncate_multiple_instructions():
    """With several instruction messages, the one that comes first by position ends up in front."""
    from livekit.agents.llm import ChatContext

    ctx = ChatContext()
    for role, text in (
        ("system", "first"),
        ("developer", "second"),
        ("user", "u1"),
        ("user", "u2"),
        ("user", "u3"),
    ):
        ctx.add_message(role=role, content=text)

    ctx.truncate(max_items=2)

    # the first instruction is the system message
    leading = ctx.items[0]
    assert leading.role == "system"
    assert leading.content == ["first"]


def test_instructions_serialization():
    """Instructions keeps its type and both variants through validation, to_dict and from_dict."""
    from livekit.agents.beta import Instructions
    from livekit.agents.llm import ChatContext, ChatMessage

    def first_content_after_round_trip(context):
        """Serialize ``context``, load it back and return the first item's first content part."""
        return ChatContext.from_dict(context.to_dict()).items[0].content[0]

    # 1. Pydantic validation keeps the Instructions instance
    dual = Instructions("audio variant", text="text variant")
    validated = ChatMessage(role="system", content=[dual])
    assert isinstance(validated.content[0], Instructions)
    assert validated.content[0].text == "text variant"

    # 2. to_dict emits a dict holding both variants
    dual_ctx = ChatContext([ChatMessage(role="system", content=[dual])])
    payload = dual_ctx.to_dict()["items"][0]["content"][0]
    assert isinstance(payload, dict)
    assert payload["audio"] == "audio variant"
    assert payload["text"] == "text variant"

    # 3. from_dict rebuilds an Instructions object
    rebuilt = first_content_after_round_trip(dual_ctx)
    assert isinstance(rebuilt, Instructions)
    assert rebuilt.audio == "audio variant"
    assert rebuilt.text == "text variant"

    # 4. plain str content is still a plain str after the round trip
    str_ctx = ChatContext([ChatMessage(role="user", content=["hello"])])
    assert type(first_content_after_round_trip(str_ctx)) is str

    # 5. without a text variant, the restored .text falls back to the audio string
    audio_only_ctx = ChatContext(
        [ChatMessage(role="system", content=[Instructions("audio only")])]
    )
    rebuilt_audio_only = first_content_after_round_trip(audio_only_ctx)
    assert isinstance(rebuilt_audio_only, Instructions)
    assert rebuilt_audio_only.audio == "audio only"
    assert rebuilt_audio_only.text == "audio only"


def test_instructions_string_operations():
    """Instructions supports ``+`` on either side (add / radd) and applies it to both variants."""
    from livekit.agents.beta import Instructions

    # Instructions + Instructions
    left = Instructions("audio A", text="text A")
    right = Instructions("audio B", text="text B")
    joined = left + right
    assert isinstance(joined, Instructions)
    assert joined.audio == "audio Aaudio B"
    assert joined.text == "text Atext B"

    # Instructions + str
    base = Instructions("audio", text="text")
    with_suffix = base + " suffix"
    assert with_suffix.audio == "audio suffix"
    assert with_suffix.text == "text suffix"

    # str + Instructions (radd)
    with_prefix = "prefix " + base
    assert with_prefix.audio == "prefix audio"
    assert with_prefix.text == "prefix text"

    # No text variant on the operand: the result has no text variant either
    extended = Instructions("audio only") + " more"
    assert extended._text_variant is None
    assert extended.audio == "audio only more"

    # Mixed case: only the left operand carries a text variant
    with_text = Instructions("audio A", text="text A")
    without_text = Instructions("audio B")
    mixed = with_text + " " + without_text
    assert mixed.audio == "audio A audio B"
    assert mixed.text == "text A audio B"


def test_instructions_as_modality():
    """as_modality() selects which variant str() yields while both variants stay available."""
    from livekit.agents.beta import Instructions
    from livekit.agents.llm import ChatContext, ChatMessage
    from livekit.agents.voice.generation import INSTRUCTIONS_MESSAGE_ID, apply_instructions_modality

    both = Instructions("audio instructions", text="text instructions")

    # audio modality: str() gives the audio variant, both attributes remain readable
    as_audio = both.as_modality("audio")
    assert str(as_audio) == "audio instructions"
    assert as_audio.audio == "audio instructions"
    assert as_audio.text == "text instructions"

    # text modality
    as_text = both.as_modality("text")
    assert str(as_text) == "text instructions"

    # an already-resolved object can be switched to the other modality
    switched_back = both.as_modality("text").as_modality("audio")
    assert str(switched_back) == "audio instructions"

    # with no text variant, both modalities give the audio string
    audio_only = Instructions("audio only")
    for modality in ("audio", "text"):
        assert str(audio_only.as_modality(modality)) == "audio only"

    # apply_instructions_modality() rewrites the instructions message in place
    ctx = ChatContext([ChatMessage(id=INSTRUCTIONS_MESSAGE_ID, role="system", content=[both])])
    for modality, expected in (("audio", "audio instructions"), ("text", "text instructions")):
        apply_instructions_modality(ctx, modality=modality)
        assert str(ctx.items[0].content[0]) == expected

    # a copy of a context resolved to text can still be resolved to audio
    base_ctx = ChatContext(
        [ChatMessage(id=INSTRUCTIONS_MESSAGE_ID, role="system", content=[both])]
    )
    first_turn = base_ctx.copy()
    apply_instructions_modality(first_turn, modality="text")
    second_turn = first_turn.copy()
    apply_instructions_modality(second_turn, modality="audio")
    assert str(second_turn.items[0].content[0]) == "audio instructions"


def test_responses_assistant_phase_round_trip():
    """A `phase` stored under message.extra["openai"] appears in the openai.responses payload."""
    from livekit.agents.llm import ChatContext

    ctx = ChatContext.empty()
    ctx.add_message(role="user", content="hello")
    for text, phase in (
        ("thinking out loud", "commentary"),
        ("the answer", "final_answer"),
    ):
        ctx.add_message(role="assistant", content=text, extra={"openai": {"phase": phase}})

    payload, _ = ctx.to_provider_format(format="openai.responses")

    phases = [entry.get("phase") for entry in payload if entry.get("role") == "assistant"]
    assert phases == ["commentary", "final_answer"]


def test_responses_assistant_phase_absent_when_not_set():
    """Assistant messages stored without a phase carry no `phase` key in the responses payload."""
    from livekit.agents.llm import ChatContext

    ctx = ChatContext.empty()
    for role, text in (("user", "hello"), ("assistant", "hi there")):
        ctx.add_message(role=role, content=text)

    payload, _ = ctx.to_provider_format(format="openai.responses")

    from_assistant = [entry for entry in payload if entry.get("role") == "assistant"]
    assert from_assistant
    for entry in from_assistant:
        assert "phase" not in entry
