# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
from typing import Any, Union
from jinja2 import Environment

from ..models.operation import OperationBase
from .import_serializer import FileImportSerializer
from .base_serializer import BaseSerializer
from ..models import (
    CodeModel,
    KeyCredentialType,
    TokenCredentialType,
    ImportType,
    OperationGroup,
    Parameter,
    BodyParameter,
    FileImport,
)
from .utils import create_fake_value

_LOGGER = logging.getLogger(__name__)


class SampleSerializer(BaseSerializer):
    """Render a sample file for one operation through ``sample.py.jinja2``.

    The serializer combines the operation's parameter metadata with the
    values found under the sample's ``"parameters"`` key to build the client
    constructor arguments, the operation call arguments, and the statements
    that consume the result.
    """

    def __init__(
        self,
        code_model: CodeModel,
        env: Environment,
        operation_group: OperationGroup,
        operation: OperationBase[Any],
        sample: dict[str, Any],
        file_name: str,
    ) -> None:
        """Store the inputs needed to render the sample.

        :param code_model: Code model forwarded to the base serializer and the template.
        :param env: Jinja2 environment used to load ``sample.py.jinja2``.
        :param operation_group: Operation group that owns ``operation``.
        :param operation: Operation the sample calls.
        :param sample: Raw sample description; its ``"parameters"`` entry
            (if any) supplies the parameter values.
        :param file_name: Value passed to the template as ``file_name``.
        """
        super().__init__(code_model, env)
        self.operation_group = operation_group
        self.operation = operation
        self.sample = sample
        self.file_name = file_name
        # Sample values keyed by wire name; empty when the sample has none.
        self.sample_params = sample.get("parameters", {})
        self._imports: str = ""

    @property
    def imports(self) -> str:
        """Import text handed to the template as ``imports`` (empty until set)."""
        return self._imports

    @imports.setter
    def imports(self, value: str) -> None:
        self._imports = value

    def _required_operation_params(self) -> list[Any]:
        """Return the operation parameters that the sample call must pass.

        A parameter qualifies when it is not optional and has no client
        default (``client_default_value is None``). Both the import
        collection and the call-argument rendering use this single
        predicate so they can never disagree about which parameters appear.
        """
        parameters = self.operation.parameters
        return [
            param
            for param in parameters.positional + parameters.keyword_only
            if not param.optional and param.client_default_value is None
        ]

    def get_file_import(self) -> FileImport:
        """Collect the imports the sample needs.

        Covers the client class, the credential helper matching the client's
        credential type, and the sample imports of every required operation
        parameter for which the sample supplies a value.

        :return: The populated ``FileImport``.
        """
        imports = FileImport(self.code_model)
        client = self.operation_group.client
        namespace = client.client_namespace
        imports.add_submodule_import(namespace, client.name, ImportType.LOCAL)
        # The client may have no credential at all, hence the getattr default.
        credential_type = getattr(client.credential, "type", None)
        if isinstance(credential_type, TokenCredentialType):
            imports.add_submodule_import("azure.identity", "DefaultAzureCredential", ImportType.SDKCORE)
        elif isinstance(credential_type, KeyCredentialType):
            # ``os`` is needed because the key is read with ``os.getenv`` in
            # the credential expression built by ``_client_params``.
            imports.add_import("os", ImportType.STDLIB)
            imports.add_submodule_import(
                "credentials",
                "AzureKeyCredential",
                ImportType.SDKCORE,
            )
        for param in self._required_operation_params():
            if param.wire_name in self.sample_params:
                imports.merge(param.type.imports_for_sample())

        return imports

    @staticmethod
    def get_imports_from_file_import(file_import: FileImport) -> str:
        """Return ``file_import`` rendered to text by ``FileImportSerializer``."""
        return str(FileImportSerializer(file_import, True))

    def _client_params(self) -> dict[str, Any]:
        """Build the client constructor arguments, keyed by client name.

        Only required parameters without a client default are included. The
        ``credential`` parameter is replaced by a credential expression when
        the client uses a token or key credential; every other parameter is
        rendered as a double-quoted string holding the sample value, or the
        upper-cased client name when the sample has no (truthy) value.
        """
        special_param: dict[str, str] = {}
        credential_type = getattr(self.operation_group.client.credential, "type", None)
        if isinstance(credential_type, TokenCredentialType):
            special_param |= {"credential": "DefaultAzureCredential()"}
        elif isinstance(credential_type, KeyCredentialType):
            special_param |= {"credential": 'AzureKeyCredential(key=os.getenv("AZURE_KEY"))'}

        params = [
            p
            for p in (
                self.operation_group.client.parameters.positional + self.operation_group.client.parameters.keyword_only
            )
            if not p.optional and p.client_default_value is None
        ]
        client_params = {
            p.client_name: special_param.get(
                p.client_name,
                f'"{self.sample_params.get(p.wire_name) or p.client_name.upper()}"',
            )
            for p in params
        }

        return client_params

    @staticmethod
    def handle_param(param: Union[Parameter, BodyParameter], param_value: Any) -> str:
        """Return the source text used for ``param_value`` in the sample.

        Strings containing a carriage return, a newline or a double quote are
        emitted as a triple-double-quoted literal; everything else is
        delegated to ``param.type.serialize_sample_value``.

        :param param: Parameter whose type serializes ordinary values.
        :param param_value: Value taken from the sample.
        :return: Text to place in the generated sample.
        """
        if isinstance(param_value, str) and any(char in param_value for char in '\r\n"'):
            # Escape so the literal always parses and evaluates back to the
            # original value: backslashes first (so the escapes added below
            # are not doubled), then quotes (a trailing ``"`` or an embedded
            # ``"""`` would otherwise terminate the literal early), then
            # carriage returns (a raw CR in source is read as a newline).
            escaped = param_value.replace("\\", "\\\\").replace('"', '\\"').replace("\r", "\\r")
            return f'"""{escaped}"""'

        return param.type.serialize_sample_value(param_value)

    def _operation_params(self) -> dict[str, Any]:
        """Build the operation call arguments, keyed by client name.

        Every required parameter without a client default gets an entry: the
        serialized sample value when the sample provides one, otherwise a
        placeholder from ``create_fake_value``. Only a missing (``None``)
        sample value triggers the placeholder, so explicit falsy values such
        as ``0``, ``False`` or ``""`` are rendered as given.
        """
        operation_params = {}
        for param in self._required_operation_params():
            param_value = self.sample_params.get(param.wire_name)
            if param_value is None:
                operation_params[param.client_name] = create_fake_value(param.type)
            else:
                operation_params[param.client_name] = self.handle_param(param, param_value)
        return operation_params

    def _operation_group_name(self) -> str:
        """Return ``.<property_name>`` for the operation group, or ``""`` for a mixin group."""
        if self.operation_group.is_mixin:
            return ""
        return f".{self.operation_group.property_name}"

    def _operation_result(self) -> tuple[str, str]:
        """Return the text that follows the call and the assignment prefix.

        The first element is ``.result()`` for long-running operations
        followed by either a loop printing each item (paging) or a single
        ``print(response)``. The second element is ``"response = "``. When
        the synchronous response annotation contains ``"None"`` the print
        statements and the assignment prefix are empty.

        :return: ``(operation_result, return_var)`` as passed to the template.
        """
        is_response_none = "None" in self.operation.response_type_annotation(async_mode=False)
        lro = ".result()"
        if is_response_none:
            paging, normal_print, return_var = "", "", ""
        else:
            paging = "\n    for item in response:\n        print(item)"
            normal_print = "\n    print(response)"
            return_var = "response = "

        if self.operation.operation_type == "paging":
            return paging, return_var
        if self.operation.operation_type == "lro":
            return lro + normal_print, return_var
        if self.operation.operation_type == "lropaging":
            return lro + paging, return_var
        return normal_print, return_var

    def _operation_name(self) -> str:
        """Return the operation name prefixed with a dot."""
        return f".{self.operation.name}"

    def _origin_file(self) -> str:
        """Return the sample's original file reference for the template.

        When ``x-ms-original-file`` contains ``"specification"``, the text
        from the last occurrence of that word onward is returned. Otherwise
        the raw value is returned only when the ``from-typespec`` option is
        truthy, and ``""`` in every other case.
        """
        name = self.sample.get("x-ms-original-file", "")
        if "specification" in name:
            return "specification" + name.split("specification")[-1]
        return name if self.code_model.options["from-typespec"] else ""

    def serialize(self) -> str:
        """Render ``sample.py.jinja2`` with the data computed by this serializer."""
        operation_result, return_var = self._operation_result()
        return self.env.get_template("sample.py.jinja2").render(
            code_model=self.code_model,
            client=self.operation_group.client,
            file_name=self.file_name,
            operation_result=operation_result,
            operation_params=self._operation_params(),
            operation_group_name=self._operation_group_name(),
            operation_name=self._operation_name(),
            imports=self.imports,
            client_params=self._client_params(),
            origin_file=self._origin_file(),
            return_var=return_var,
        )
