# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import json
from collections import namedtuple
import re
from typing import Any, Optional, Union
from pathlib import Path
from packaging.version import parse as parse_version
from jinja2 import PackageLoader, Environment, FileSystemLoader, StrictUndefined

from ... import ReaderAndWriter
from ..models import (
    OperationGroup,
    RequestBuilder,
    OverloadedRequestBuilder,
    CodeModel,
    Client,
    ModelType,
    EnumType,
)
from .enum_serializer import EnumSerializer
from .general_serializer import GeneralSerializer
from .model_init_serializer import ModelInitSerializer
from .model_serializer import DpgModelSerializer, MsrestModelSerializer
from .operations_init_serializer import OperationsInitSerializer
from .operation_groups_serializer import OperationGroupsSerializer
from .request_builders_serializer import RequestBuildersSerializer
from .patch_serializer import PatchSerializer
from .sample_serializer import SampleSerializer
from .test_serializer import TestSerializer, TestGeneralSerializer
from .types_serializer import TypesSerializer
from ...utils import to_snake_case, VALID_PACKAGE_MODE
from .utils import extract_sample_name, get_namespace_from_package_name, get_namespace_config, hash_file_import

# Module-level logger; used below to report per-sample / per-test rendering failures.
_LOGGER = logging.getLogger(__name__)

# Only the serializer class is exported from this module.
__all__ = ["JinjaSerializer"]

# Base packaging templates rendered into the SDK root in package mode. JinjaSerializer adds
# either the setup.py or the pyproject.toml template on top of this list, depending on options.
_PACKAGE_FILES = [
    "CHANGELOG.md.jinja2",
    "dev_requirements.txt.jinja2",
    "LICENSE.jinja2",
    "MANIFEST.in.jinja2",
    "README.md.jinja2",
]

# Output files that are rewritten even when they already exist; other packaging files are
# only created when missing.
_REGENERATE_FILES = {"MANIFEST.in"}

# One serialization pass: `async_mode` selects sync vs. async rendering and `async_path` is the
# sub-folder prefix for that pass ("" for sync, "aio/" for async).
AsyncInfo = namedtuple("AsyncInfo", "async_mode async_path")


def _sample_output_path(source_file_path: str) -> Path:
    """Return the snake_cased sub-folders that follow the first ``examples/`` in a sample path.

    For example, for
    ``"xxx/resource-manager/Microsoft.XX/stable/2023-04-01/examples/Compute/createOrUpdate/AKSCompute.json"``
    the folders after ``examples/`` are ``Compute/createOrUpdate``, which yields
    ``Path("compute/create_or_update")``; the file name itself is dropped.
    Returns ``Path("")`` when ``examples/`` does not occur in the path.
    """
    marker = "examples/"
    normalized = Path(source_file_path).as_posix()
    if marker not in normalized:
        return Path("")
    # everything after the first occurrence of the marker
    _, _, remainder = normalized.partition(marker)
    folders = Path(remainder).parent.parts
    return Path("/".join(to_snake_case(folder) for folder in folders))


class JinjaSerializer(ReaderAndWriter):
    """Render a ``CodeModel`` into SDK source, sample, test and packaging files via Jinja templates.

    All file access goes through the ``ReaderAndWriter`` helpers (``read_file``, ``write_file``,
    ``remove_file``, ``remove_folder``, ``list_file_of_folder``).
    """

    def __init__(
        self,
        code_model: CodeModel,
        *,
        output_folder: Union[str, Path],
        **kwargs: Any,
    ) -> None:
        super().__init__(output_folder=output_folder, **kwargs)
        self.code_model = code_model
        # Packaging templates to render and output files that are always rewritten, for this
        # serializer only; filled in by `_regenerate_setup_py` from the module-level defaults.
        self._package_files: list[str] = []
        self._regenerate_files: set[str] = set()
        self._regenerate_setup_py()

    def _regenerate_setup_py(self) -> None:
        """Select setup.py- or pyproject.toml-based packaging for this serializer.

        ``setup.py`` is used when ``keep-setup-py`` or ``basic-setup-py`` is set, otherwise
        ``pyproject.toml``: its template is added to ``self._package_files`` and the file name to
        ``self._regenerate_files``. Both are rebuilt from copies of ``_PACKAGE_FILES`` and
        ``_REGENERATE_FILES`` so the shared module-level defaults are never mutated (mutating them
        made templates accumulate across serializer instances created in the same process).
        """
        self._package_files = list(_PACKAGE_FILES)
        self._regenerate_files = set(_REGENERATE_FILES)
        if self.code_model.options["keep-setup-py"] or self.code_model.options["basic-setup-py"]:
            self._package_files.append("setup.py.jinja2")
            self._regenerate_files.add("setup.py")
        else:
            self._package_files.append("pyproject.toml.jinja2")
            self._regenerate_files.add("pyproject.toml")

    @property
    def has_aio_folder(self) -> bool:
        """Whether an async (``aio``) variant is generated: ``no-async`` unset and operations exist."""
        return not self.code_model.options["no-async"] and bool(self.code_model.has_operations)

    @property
    def has_operations_folder(self) -> bool:
        """Whether operations are exposed: ``show-operations`` set and operations exist."""
        return bool(self.code_model.options["show-operations"]) and bool(self.code_model.has_operations)

    @property
    def serialize_loop(self) -> list[AsyncInfo]:
        """Passes to run: always the sync one, plus the ``aio/`` one when ``has_aio_folder``."""
        sync_loop = AsyncInfo(async_mode=False, async_path="")
        async_loop = AsyncInfo(async_mode=True, async_path="aio/")
        return [sync_loop, async_loop] if self.has_aio_folder else [sync_loop]

    @property
    def keep_version_file(self) -> bool:
        """Whether an existing version file should be kept instead of regenerated.

        True when the ``keep-version-file`` option is set, or when the ``VERSION = "..."`` found in
        the existing ``_version.py`` of the main namespace parses as a newer version than the
        ``package-version`` option. A version that cannot be parsed yields False. The file is
        read on every access.
        """
        if self.options.get("keep-version-file"):
            return True
        # If the version file is already there and the version is greater than the current version, keep it.
        try:
            serialized_version_file = self.read_file(
                self.code_model.get_generation_dir(self.code_model.namespace) / "_version.py"
            )
            match = re.search(r'VERSION\s*=\s*"([^"]+)"', str(serialized_version_file))
            serialized_version = match.group(1) if match else ""
        except (FileNotFoundError, IndexError):
            serialized_version = ""
        try:
            return parse_version(serialized_version) > parse_version(self.code_model.options.get("package-version", ""))
        except Exception:  # pylint: disable=broad-except
            # If parsing the version fails, we assume the version file is not valid and overwrite.
            return False

    # pylint: disable=too-many-branches
    def serialize(self) -> None:
        """Write all generated files for every client namespace of the code model.

        When generating from TypeSpec with ``clear-output-folder`` set, the generated samples and
        tests folders and the non-``_patch.py`` ``.py`` files of the main generation folder are
        removed first. Then, for each namespace in ``code_model.client_namespace_types``: the root
        namespace (``""``) gets setup/packaging files, ``apiview-properties.json``, samples, tests
        and ``_metadata.json`` as enabled by options; other namespaces get client/config files, or
        a pkgutil ``__init__.py`` when they have no clients. Top-level helpers, the models folder
        and the operations folder are added where applicable.
        """
        # remove existing folders when generate from tsp
        if self.code_model.is_tsp and self.code_model.options.get("clear-output-folder"):
            # remove generated_samples and generated_tests folder
            self.remove_folder(self._generated_tests_samples_folder("generated_samples"))
            self.remove_folder(self._generated_tests_samples_folder("generated_tests"))

            # remove generated sdk files
            generation_path = self.code_model.get_generation_dir(self.code_model.namespace)
            for file in self.list_file_of_folder(generation_path):
                if file.endswith(".py") and "_patch.py" not in file:
                    self.remove_file(file)

        # serialize logic
        env = Environment(
            loader=PackageLoader("pygen.codegen", "templates"),
            keep_trailing_newline=True,
            line_statement_prefix="##",
            line_comment_prefix="###",
            trim_blocks=True,
            lstrip_blocks=True,
        )

        general_serializer = GeneralSerializer(code_model=self.code_model, env=env, async_mode=False)
        for client_namespace, client_namespace_type in self.code_model.client_namespace_types.items():
            generation_path = self.code_model.get_generation_dir(client_namespace)
            if client_namespace == "":
                if self.code_model.options["basic-setup-py"]:
                    # Write the setup file
                    self.write_file(generation_path / Path("setup.py"), general_serializer.serialize_setup_file())
                elif not self.code_model.options["keep-setup-py"]:
                    # remove setup.py file
                    self.remove_file(generation_path / Path("setup.py"))

                # add packaging files in root namespace (e.g. setup.py, README.md, etc.)
                if self.code_model.options.get("package-mode"):
                    self._serialize_and_write_package_files()

                # write apiview-properties.json
                if self.code_model.options.get("emit-cross-language-definition-file"):
                    self.write_file(
                        self._root_of_sdk / Path("apiview-properties.json"),
                        general_serializer.serialize_cross_language_definition_file(),
                    )

                # add generated samples and generated tests
                if self.code_model.options["show-operations"] and self.code_model.has_operations:
                    if self.code_model.options["generate-sample"]:
                        self._serialize_and_write_sample(env)
                    if self.code_model.options["generate-test"]:
                        self._serialize_and_write_test(env)

                # add _metadata.json
                if self.code_model.metadata:
                    self.write_file(
                        self._root_of_sdk / "_metadata.json",
                        json.dumps(self.code_model.metadata, indent=2),
                    )
            elif client_namespace_type.clients:
                # add clients folder if there are clients in this namespace
                self._serialize_client_and_config_files(client_namespace, client_namespace_type.clients, env)
            else:
                # add pkgutil init file if no clients in this namespace
                self.write_file(
                    generation_path / Path("__init__.py"),
                    general_serializer.serialize_pkgutil_init_file(),
                )

            # _utils/py.typed/_types.py/_validation.py
            # is always put in top level namespace
            if self.code_model.is_top_namespace(client_namespace):
                self._serialize_and_write_top_level_folder(env=env, namespace=client_namespace)

            # add models folder if there are models in this namespace
            if (
                self.code_model.has_non_json_models(client_namespace_type.models) or client_namespace_type.enums
            ) and self.code_model.options["models-mode"]:
                self._serialize_and_write_models_folder(
                    env=env,
                    namespace=client_namespace,
                    models=client_namespace_type.models,
                    enums=client_namespace_type.enums,
                )

            if not self.code_model.options["models-mode"]:
                # keep models file if users ended up just writing a models file
                model_path = generation_path / Path("models.py")
                existing_models = self.read_file(model_path)
                if existing_models:
                    self.write_file(model_path, existing_models)

            # add operations folder if there are operations in this namespace
            if client_namespace_type.operation_groups:
                self._serialize_and_write_operations_folder(
                    client_namespace_type.operation_groups, env=env, namespace=client_namespace
                )

            # if there are only operations under this namespace, we need to add general __init__.py into `aio` folder
            # to make sure all generated files could be packed into .zip/.whl/.tgz package
            if not client_namespace_type.clients and client_namespace_type.operation_groups and self.has_aio_folder:
                self.write_file(
                    generation_path / Path("aio/__init__.py"),
                    general_serializer.serialize_pkgutil_init_file(),
                )

    @property
    def _root_of_sdk(self) -> Path:
        """Path of the SDK root (where README.md lives), relative to the output folder.

        With ``no-namespace-folders`` the output folder is the namespace folder itself, so one
        ``../`` is added per dot-separated namespace segment.
        """
        root_of_sdk = Path(".")
        if self.code_model.options["no-namespace-folders"]:
            compensation = Path("../" * (self.code_model.namespace.count(".") + 1))
            root_of_sdk = root_of_sdk / compensation
        return root_of_sdk

    def _serialize_and_write_package_files(self) -> None:
        """Render packaging files (README, CHANGELOG, setup.py/pyproject.toml, ...) into the SDK root.

        If ``package-mode`` is one of ``VALID_PACKAGE_MODE`` the built-in packaging templates are
        used (LICENSE is skipped without a license description); if it is an existing path, every
        template in that folder is rendered; otherwise nothing is written. A file is only written
        when it does not exist yet or is listed in ``self._regenerate_files``.
        """
        root_of_sdk = self._root_of_sdk
        if self.code_model.options["package-mode"] in VALID_PACKAGE_MODE:
            env = Environment(
                loader=PackageLoader("pygen.codegen", "templates/packaging_templates"),
                undefined=StrictUndefined,
                trim_blocks=True,
                lstrip_blocks=True,
            )

            # copy so that dropping LICENSE does not alter the serializer's template list
            package_files = list(self._package_files)
            if not self.code_model.license_description:
                package_files.remove("LICENSE.jinja2")
        elif Path(self.code_model.options["package-mode"]).exists():
            env = Environment(
                loader=FileSystemLoader(str(Path(self.code_model.options["package-mode"]))),
                keep_trailing_newline=True,
                undefined=StrictUndefined,
            )
            package_files = env.list_templates()
        else:
            return
        serializer = GeneralSerializer(self.code_model, env, async_mode=False)
        # `or {}` also covers an explicit None value, which would break `**params`
        params = self.code_model.options.get("packaging-files-config") or {}
        for template_name in package_files:
            if not self.code_model.is_azure_flavor and template_name == "dev_requirements.txt.jinja2":
                continue
            file = template_name.replace(".jinja2", "")
            output_file = root_of_sdk / file
            if not self.read_file(output_file) or file in self._regenerate_files:
                # don't regenerate setup.py file if the version file is more up to date for data-plane;
                # the file-name check comes first so the file-reading `keep_version_file` property
                # is only evaluated for setup.py
                if file == "setup.py" and self.keep_version_file and not self.code_model.options["azure-arm"]:
                    continue
                file_content = self.read_file(output_file) if file == "pyproject.toml" else ""
                self.write_file(
                    output_file,
                    serializer.serialize_package_file(template_name, file_content, **params),
                )

    def _keep_patch_file(self, path_file: Path, env: Environment) -> None:
        """Carry over an existing ``_patch.py`` unchanged, or create one from the patch template."""
        existing_patch = self.read_file(path_file)
        if existing_patch:
            self.write_file(path_file, existing_patch)
        else:
            self.write_file(
                path_file,
                PatchSerializer(env=env, code_model=self.code_model).serialize(),
            )

    def _serialize_and_write_models_folder(
        self, env: Environment, namespace: str, models: list[ModelType], enums: list[EnumType]
    ) -> None:
        """Write ``models/`` for ``namespace``: models module, enums module, ``__init__.py`` and ``_patch.py``.

        The models module is only written when there are non-JSON models and the enums module
        only when ``enums`` is non-empty. ``models-mode == "dpg"`` selects the DPG model
        serializer, any other mode the msrest one.
        """
        models_path = self.code_model.get_generation_dir(namespace) / "models"
        serializer = DpgModelSerializer if self.code_model.options["models-mode"] == "dpg" else MsrestModelSerializer
        if self.code_model.has_non_json_models(models):
            self.write_file(
                models_path / Path(f"{self.code_model.models_filename}.py"),
                serializer(code_model=self.code_model, env=env, client_namespace=namespace, models=models).serialize(),
            )
        if enums:
            self.write_file(
                models_path / Path(f"{self.code_model.enums_filename}.py"),
                EnumSerializer(
                    code_model=self.code_model, env=env, client_namespace=namespace, enums=enums
                ).serialize(),
            )
        self.write_file(
            models_path / Path("__init__.py"),
            ModelInitSerializer(code_model=self.code_model, env=env, models=models, enums=enums).serialize(),
        )

        self._keep_patch_file(models_path / Path("_patch.py"), env)

    def _serialize_and_write_rest_layer(self, env: Environment, namespace_path: Path) -> None:
        """Write the rest layer under ``namespace_path``: one request-builders package per group name.

        When no request builder has the empty group name, the rest folder's ``__init__.py`` is
        written with only the license header.
        """
        rest_path = namespace_path / Path(self.code_model.rest_layer_name)
        group_names = {rb.group_name for c in self.code_model.clients for rb in c.request_builders}

        for group_name in group_names:
            request_builders = [
                r for c in self.code_model.clients for r in c.request_builders if r.group_name == group_name
            ]
            self._serialize_and_write_single_rest_layer(env, rest_path, request_builders)
        if "" not in group_names:
            self.write_file(
                rest_path / Path("__init__.py"),
                self.code_model.license_header,
            )

    def _serialize_and_write_single_rest_layer(
        self,
        env: Environment,
        rest_path: Path,
        request_builders: list[Union[RequestBuilder, OverloadedRequestBuilder]],
    ) -> None:
        """Write ``_request_builders.py`` and ``__init__.py`` for one request-builder group.

        The group is taken from the first builder (callers pass a non-empty list sharing one
        group name); an empty group name writes directly into ``rest_path``.
        """
        group_name = request_builders[0].group_name
        output_path = rest_path / Path(group_name) if group_name else rest_path
        # write generic request builders file
        self.write_file(
            output_path / Path("_request_builders.py"),
            RequestBuildersSerializer(
                code_model=self.code_model,
                env=env,
                request_builders=request_builders,
            ).serialize_request_builders(),
        )

        # write rest init file
        self.write_file(
            output_path / Path("__init__.py"),
            RequestBuildersSerializer(
                code_model=self.code_model,
                env=env,
                request_builders=request_builders,
            ).serialize_init(),
        )

    def _serialize_and_write_operations_folder(
        self, operation_groups: list[OperationGroup], env: Environment, namespace: str
    ) -> None:
        """Write the operations package for ``namespace`` for each sync/async pass.

        Each package gets an ``__init__.py``, the operation module(s) and a kept or freshly
        generated ``_patch.py``. With ``combine-operation-files`` all groups share one
        ``_operations.py``; otherwise each group goes to ``<og.filename>.py``.
        """
        operations_folder_name = self.code_model.operations_folder_name(namespace)
        generation_path = self.code_model.get_generation_dir(namespace)

        # (operation groups, module filename) pairs; the same for sync and async, so built once
        og_files: list[tuple[list[OperationGroup], str]]
        if self.code_model.options["combine-operation-files"]:
            og_files = [(operation_groups, "_operations")]
        else:
            og_files = [([og], og.filename) for og in operation_groups]

        for async_mode, async_path in self.serialize_loop:
            prefix_path = f"{async_path}{operations_folder_name}"
            # write init file
            operations_init_serializer = OperationsInitSerializer(
                code_model=self.code_model, operation_groups=operation_groups, env=env, async_mode=async_mode
            )
            self.write_file(
                generation_path / Path(f"{prefix_path}/__init__.py"),
                operations_init_serializer.serialize(),
            )

            # write operations file
            for ogs, filename in og_files:
                operation_group_serializer = OperationGroupsSerializer(
                    code_model=self.code_model,
                    operation_groups=ogs,
                    env=env,
                    async_mode=async_mode,
                    client_namespace=namespace,
                )
                self.write_file(
                    generation_path / Path(f"{prefix_path}/{filename}.py"),
                    operation_group_serializer.serialize(),
                )

            # if there was a patch file before, we keep it
            self._keep_patch_file(generation_path / Path(f"{prefix_path}/_patch.py"), env)

    def _serialize_and_write_version_file(
        self,
        general_serializer: GeneralSerializer,
        namespace: Optional[str] = None,
    ) -> None:
        """Write ``_version.py`` into ``namespace``'s generation folder (or the root dir if no namespace).

        When ``keep_version_file`` is true, the content of an existing ``_version.py`` (or, failing
        that, ``version.py``) is written to ``_version.py``. Otherwise, or when neither exists, a
        fresh version file is rendered if the ``package-version`` option is set.
        """
        if namespace:
            generation_path = self.code_model.get_generation_dir(namespace)
        else:
            generation_path = self.code_model.get_root_dir()

        if self.keep_version_file:
            for original_version_file_name in ("_version.py", "version.py"):
                existing_content = self.read_file(generation_path / original_version_file_name)
                if existing_content:
                    self.write_file(generation_path / Path("_version.py"), existing_content)
                    return
        if self.code_model.options.get("package-version"):
            self.write_file(
                generation_path / Path("_version.py"),
                general_serializer.serialize_version_file(),
            )

    def _serialize_client_and_config_files(
        self,
        namespace: str,
        clients: list[Client],
        env: Environment,
    ) -> None:
        """Write ``__init__.py``, ``_patch.py`` and client/configuration modules for ``namespace``.

        Done once per sync/async pass. The client and configuration modules, and afterwards the
        ``_utils`` folder, are only written when ``clients_has_operations(clients)`` is true.
        """
        generation_path = self.code_model.get_generation_dir(namespace)
        for async_mode, async_path in self.serialize_loop:
            general_serializer = GeneralSerializer(
                code_model=self.code_model, env=env, async_mode=async_mode, client_namespace=namespace
            )
            # when there is client.py, there must be __init__.py
            self.write_file(
                generation_path / Path(f"{async_path}__init__.py"),
                general_serializer.serialize_init_file([c for c in clients if c.has_operations]),
            )

            # if there was a patch file before, we keep it
            self._keep_patch_file(generation_path / Path(f"{async_path}_patch.py"), env)

            if self.code_model.clients_has_operations(clients):

                # write client file
                self.write_file(
                    generation_path / Path(f"{async_path}{self.code_model.client_filename}.py"),
                    general_serializer.serialize_service_client_file(clients),
                )

                # write config file
                self.write_file(
                    generation_path / Path(f"{async_path}_configuration.py"),
                    general_serializer.serialize_config_file(clients),
                )

        # sometimes we need define additional Mixin class for client in _utils.py;
        # the _utils content does not depend on the sync/async pass, so write it once
        if self.code_model.clients_has_operations(clients):
            self._serialize_and_write_utils_folder(env, namespace)

    def _serialize_and_write_utils_folder(self, env: Environment, namespace: str) -> None:
        """Write the ``_utils`` package into ``namespace``'s generation folder.

        Writes ``__init__.py`` (license header), ``utils.py``, ``serialization.py`` and, in
        ``dpg`` models mode, ``model_base.py`` — each only when the code model calls for it.
        """
        generation_dir = self.code_model.get_generation_dir(namespace)
        general_serializer = GeneralSerializer(code_model=self.code_model, env=env, async_mode=False)
        utils_folder_path = generation_dir / Path("_utils")
        # NOTE(review): the need_utils_* checks use the top-level namespace rather than
        # `namespace`, while files are written under `namespace` — confirm this is intended.
        if self.code_model.need_utils_folder(async_mode=False, client_namespace=self.code_model.namespace):
            self.write_file(
                utils_folder_path / Path("__init__.py"),
                self.code_model.license_header,
            )
        if self.code_model.need_utils_utils(async_mode=False, client_namespace=self.code_model.namespace):
            self.write_file(
                utils_folder_path / Path("utils.py"),
                general_serializer.need_utils_utils_file(),
            )
        # write _utils/serialization.py
        if self.code_model.need_utils_serialization:
            self.write_file(
                utils_folder_path / Path("serialization.py"),
                general_serializer.serialize_serialization_file(),
            )

        # write _model_base.py
        if self.code_model.options["models-mode"] == "dpg":
            self.write_file(
                utils_folder_path / Path("model_base.py"),
                general_serializer.serialize_model_base_file(),
            )

    def _serialize_and_write_top_level_folder(self, env: Environment, namespace: str) -> None:
        """Write top-level helpers: ``_utils``, ``_version.py``, ``py.typed`` and optional modules.

        ``_validation.py`` is written when any operation group needs validation and ``_types.py``
        when the code model has named unions. With ``generation-subdir`` an extra
        ``_version.py`` is written into ``namespace``'s generation folder.
        """
        root_dir = self.code_model.get_root_dir()
        # write _utils folder
        self._serialize_and_write_utils_folder(env, self.code_model.namespace)

        general_serializer = GeneralSerializer(code_model=self.code_model, env=env, async_mode=False)

        # write _version.py
        self._serialize_and_write_version_file(general_serializer)
        # if there's a subdir, we need to write another version file in the subdir
        if self.code_model.options.get("generation-subdir"):
            self._serialize_and_write_version_file(general_serializer, namespace)

        # write the empty py.typed file
        pytyped_value = "# Marker file for PEP 561."
        self.write_file(root_dir / Path("py.typed"), pytyped_value)

        # write _validation.py
        if any(og for client in self.code_model.clients for og in client.operation_groups if og.need_validation):
            self.write_file(
                root_dir / Path("_validation.py"),
                general_serializer.serialize_validation_file(),
            )

        # write _types.py
        if self.code_model.named_unions:
            self.write_file(
                root_dir / Path("_types.py"),
                TypesSerializer(code_model=self.code_model, env=env).serialize(),
            )

    # pylint: disable=line-too-long
    @property
    def sample_additional_folder(self) -> Path:
        """Extra sample sub-folder: namespace segments beyond those implied by ``package-name``."""
        # For special package, we need to additional folder when generate samples.
        # For example, azure-mgmt-resource is combined by multiple modules, and each module is a package.
        # one of namespace is "azure.mgmt.resource.resources.v2020_01_01", then additional folder is "resources"
        # so that we could avoid conflict when generate samples.
        # python config: https://github.com/Azure/azure-rest-api-specs/blob/main/specification/resources/resource-manager/readme.python.md
        # generated SDK: https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/resources/azure-mgmt-resource/generated_samples
        namespace_config = get_namespace_config(self.code_model.namespace)
        num_of_namespace = namespace_config.count(".") + 1
        num_of_package_namespace = (
            get_namespace_from_package_name(self.code_model.options.get("package-name", "")).count(".") + 1
        )
        if num_of_namespace > num_of_package_namespace:
            return Path("/".join(namespace_config.split(".")[num_of_package_namespace:]))
        return Path("")

    def _generated_tests_samples_folder(self, folder_name: str) -> Path:
        """Return ``folder_name`` resolved under the SDK root."""
        return self._root_of_sdk / folder_name

    def _process_operation_samples(
        self,
        samples: dict,
        env: Environment,
        op_group,
        operation,
        import_sample_cache: dict[tuple[str, str], str],
        out_path: Path,
        sample_additional_folder: Path,
    ) -> None:
        """Process samples for a single operation.

        Each sample is written to
        ``out_path / sample_additional_folder / <folders after examples/> / <snake_case name>.py``.
        Rendered imports are cached in ``import_sample_cache`` per (client namespace, import hash).
        Errors for an individual sample are logged and that sample is skipped.
        """
        for sample_value in samples.values():
            file = sample_value.get("x-ms-original-file", "sample.json")
            file_name = to_snake_case(extract_sample_name(file)) + ".py"
            try:
                sample_ser = SampleSerializer(
                    code_model=self.code_model,
                    env=env,
                    operation_group=op_group,
                    operation=operation,
                    sample=sample_value,
                    file_name=file_name,
                )
                file_import = sample_ser.get_file_import()
                imports_hash_string = hash_file_import(file_import)
                cache_key = (op_group.client.client_namespace, imports_hash_string)
                if cache_key not in import_sample_cache:
                    import_sample_cache[cache_key] = sample_ser.get_imports_from_file_import(file_import)
                sample_ser.imports = import_sample_cache[cache_key]

                content = sample_ser.serialize()
                output_path = out_path / sample_additional_folder / _sample_output_path(file) / file_name
                self.write_file(output_path, content)
            except Exception as e:  # pylint: disable=broad-except
                _LOGGER.error("error happens in sample %s: %s", file, e)

    def _serialize_and_write_sample(self, env: Environment) -> None:
        """Write ``generated_samples`` from the ``samples`` yaml data of every public operation."""
        out_path = self._generated_tests_samples_folder("generated_samples")
        sample_additional_folder = self.sample_additional_folder

        # Cache import_test per (client_namespace, imports_hash_string) since it's expensive to compute
        import_sample_cache: dict[tuple[str, str], str] = {}

        for client in self.code_model.clients:
            for op_group in client.operation_groups:
                for operation in op_group.operations:
                    samples = operation.yaml_data.get("samples")
                    # skip operations without samples and private ("_"-prefixed) operations
                    if not samples or operation.name.startswith("_"):
                        continue
                    self._process_operation_samples(
                        samples,
                        env,
                        op_group,
                        operation,
                        import_sample_cache,
                        out_path,
                        sample_additional_folder,
                    )

    def _serialize_and_write_test(self, env: Environment) -> None:
        """Write the ``generated_tests`` folder.

        Writes ``conftest.py``, sync/async test preparers (non-ARM only) and one sync plus one
        async test module per operation group. ``code_model.for_test`` is set while rendering and
        always reset afterwards, even if an error escapes. Errors while rendering one operation
        group's tests are logged and that group is skipped.
        """
        self.code_model.for_test = True
        try:
            out_path = self._generated_tests_samples_folder("generated_tests")

            general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
            self.write_file(out_path / "conftest.py", general_serializer.serialize_conftest())

            if not self.code_model.options["azure-arm"]:
                for async_mode in (True, False):
                    async_suffix = "_async" if async_mode else ""
                    general_serializer.async_mode = async_mode
                    self.write_file(
                        out_path / f"testpreparer{async_suffix}.py",
                        general_serializer.serialize_testpreparer(),
                    )

            # Generate test files - reuse serializer per operation group, toggle async_mode
            # Cache import_test per (client.name, async_mode) since it's expensive to compute
            import_test_cache: dict[tuple[str, bool], str] = {}
            for client in self.code_model.clients:
                for og in client.operation_groups:
                    # Create serializer once per operation group
                    test_serializer = TestSerializer(self.code_model, env, client=client, operation_group=og)
                    try:
                        for async_mode in (True, False):
                            test_serializer.async_mode = async_mode
                            cache_key = (client.name, async_mode)
                            if cache_key not in import_test_cache:
                                import_test_cache[cache_key] = test_serializer.get_import_test()
                            test_serializer.import_test = import_test_cache[cache_key]
                            content = test_serializer.serialize_test()
                            output_path = out_path / f"{to_snake_case(test_serializer.test_class_name)}.py"
                            self.write_file(output_path, content)
                    except Exception as e:  # pylint: disable=broad-except
                        _LOGGER.error(
                            "error happens in test generation for operation group %s: %s", og.class_name, e
                        )
        finally:
            self.code_model.for_test = False