import os
from threading import Thread
from time import sleep
import queue
import boto3
from actions_logging.app_logging import logger
from aws.constants import ENVS, ECS_REACH_STEADY_STATE, PROD_ENVS, STG_ENVS, PRODUCTION, STAGING, \
    EC2_ECS_CLUSTER_SERVICES, STG_PROD_ENVS
from datadog.send_log_metric import format_log_and_send_to_datadog
from env_files.constants import DEFAULT_ENV_FILES_PATH
from github.env import exit_on_error_and_write_summary, get_required_env_var


def get_task_and_cluster_names(ms_name, env_name) -> tuple[str, str]:
    """
    build the task name and the cluster name used to update an ecs service
    :param ms_name: service name; services listed in EC2_ECS_CLUSTER_SERVICES get an extra "ec2" cluster part
    :param env_name: environment name; it is the first part of the cluster name
    :return: (task_name, cluster_name)
    """
    # cluster name layout: <env_name>[-ec2]-micro-services
    ec2_part = ["ec2"] if ms_name in EC2_ECS_CLUSTER_SERVICES else []
    cluster_name = "-".join([env_name, *ec2_part, "micro-services"])

    # envs in PROD_ENVS / STG_ENVS share the PRODUCTION / STAGING task suffix,
    # any other env keeps its own name as the suffix
    if env_name in PROD_ENVS:
        task_env = PRODUCTION
    elif env_name in STG_ENVS:
        task_env = STAGING
    else:
        task_env = env_name
    task_name = f"{ms_name}-family-{task_env}"

    logger.info(f"ecs service cluster name to update: {cluster_name} with task name: {task_name}")
    return task_name, cluster_name


def check_if_service_is_flapping_and_send_log_to_dd(response, ms_name):
    """
    best-effort flapping check: read 'failedTasks' from the first deployment of the first
    service in the response and, when it is greater than 0, send an ERROR log to datadog.
    any Exception raised while doing so is logged with logger.error and not propagated
    :param response: dict indexed as response['services'][0]['deployments'][0]['failedTasks']
    :param ms_name: service name, used only inside the log messages
    :return: None
    """
    try:
        logger.debug(f"service {ms_name} in {STG_PROD_ENVS}. will check if flapping and will send log to DD if so")
        first_service = response['services'][0]
        first_deployment = first_service['deployments'][0]
        failed_tasks = first_deployment['failedTasks']
        logger.debug(f"failed_tasks for {ms_name}: {failed_tasks}")
        if not failed_tasks > 0:
            # nothing failed - nothing to report
            return
        msg = f"[{ms_name}] the service might be in flapping mode. failed_tasks: {failed_tasks}"
        format_log_and_send_to_datadog(msg, 'ERROR', deploy_tag='deploy')
    except Exception as e:
        logger.error(f"error in check_if_service_is_flapping_and_send_log_to_dd for service: {ms_name} - {e}")


def get_describe_service_message(client, ms_name, cluster_name, env_name) -> tuple[str, str]:
    """
    describe one ecs service and return the message and id of its first (latest) event.
    for envs in STG_PROD_ENVS the response is also passed to the flapping check
    :param client: ecs client exposing describe_services(cluster=..., services=[...])
    :param ms_name: service name to describe
    :param cluster_name: cluster that holds the service
    :param env_name: environment name, only used to decide whether to run the flapping check
    :return: (message, event_id) of the first event in the response
    :raises RuntimeError: on any failure - api error, service missing from the response,
        or a service without events; the original exception is chained as the cause
    """
    try:
        response = client.describe_services(
            cluster=cluster_name,
            services=[ms_name]
        )
        services = response.get('services') or []
        if not services:
            # an unknown service gives an empty 'services' list; say so explicitly instead of
            # failing later with an opaque "list index out of range"
            raise LookupError(f"service not found in describe_services response. "
                              f"failures: {response.get('failures')}")
        if env_name in STG_PROD_ENVS:
            check_if_service_is_flapping_and_send_log_to_dd(response, ms_name)
        # response will return one service and the 0 event is the latest event
        events = services[0].get('events') or []
        if not events:
            raise LookupError("service has no events")
        latest_event = events[0]
        message = latest_event['message']
        event_id = latest_event['id']
        logger.debug(f"message for service: {ms_name} in cluster: {cluster_name} - {message}. id: {event_id}")
        return message, event_id
    except Exception as e:
        # chain the cause so the original traceback is kept
        raise RuntimeError(
            f"error in get_describe_service_message for service: {ms_name} in cluster: {cluster_name} - {e}"
        ) from e


def update_service(client, ms_name, task_name, cluster_name):
    """
    call client.update_service for one service, forcing a new deployment and enabling execute-command.
    any Exception from the call is handed to exit_on_error_and_write_summary
    :param client: ecs client exposing update_service(...)
    :param ms_name: service name to update
    :param task_name: value passed as taskDefinition
    :param cluster_name: cluster that holds the service
    :return: None
    """
    update_kwargs = {
        "cluster": cluster_name,
        "service": ms_name,
        "taskDefinition": task_name,
        "forceNewDeployment": True,
        "enableExecuteCommand": True,
    }
    try:
        client.update_service(**update_kwargs)
        logger.info_green(f"update start for service: {ms_name} in cluster: {cluster_name}")
    except Exception as e:
        exit_on_error_and_write_summary(f"error in update_service for service: {ms_name} in cluster: {cluster_name} - {e}")


def wait_for_service_to_update(client, ms_name, cluster_name, init_message_id,
                               env_name, timeout_minutes=20, errors_queue=None):
    """
    poll the latest service event every 10 seconds until it contains ECS_REACH_STEADY_STATE
    and is not the event that was already there before the update started
    :param client: ecs client, forwarded to get_describe_service_message
    :param ms_name: service name to wait for
    :param cluster_name: cluster that holds the service
    :param init_message_id: id of the steady-state event seen before the update (falsy when there was none);
        a steady-state event with this same id is treated as stale and ignored
    :param env_name: environment name, forwarded to get_describe_service_message
    :param timeout_minutes: polling budget; the number of polls is timeout_minutes * 60 / 10
    :param errors_queue: when given, errors are put on it; otherwise they go to exit_on_error_and_write_summary
    :return: None once the service reached a steady state (or after an error was reported)
    """
    poll_interval_seconds = 10
    try:
        loops = int(timeout_minutes * 60 / poll_interval_seconds)
        logger.info_green(f"wait for service {ms_name} to update with timeout {timeout_minutes} minutes")
        for attempt in range(loops):
            if attempt:
                # sleep only between polls - no extra wait after the last poll before timing out
                sleep(poll_interval_seconds)
            message, event_id = get_describe_service_message(client, ms_name, cluster_name, env_name)
            logger.info_yellow(f"message for service: {ms_name} in cluster: {cluster_name} - {message}")
            if ECS_REACH_STEADY_STATE not in message:
                continue
            if init_message_id and init_message_id == event_id:
                # same event id as before the update: the new deployment has not reported yet
                logger.info_green(f"the message '{ECS_REACH_STEADY_STATE}' is still from old event")
                continue
            logger.debug(f"reach steady state for service: {ms_name} with message id: {event_id}")
            logger.info_green_bg(f"ya-ba-da-ba-du the service {ms_name} has reached a steady state")
            return
        raise TimeoutError(f"service {ms_name} in cluster {cluster_name} did not reach steady state in {timeout_minutes} minutes")
    except Exception as e:
        error_msg = f"error in wait_for_service_to_update for service: {ms_name} in cluster: {cluster_name} - {e}"
        # None is the "no queue" sentinel; do not rely on the truthiness of the queue object
        if errors_queue is not None:
            errors_queue.put(error_msg)
        else:
            exit_on_error_and_write_summary(error_msg)


def update_one_ecs_service(env_name, ms_name, session_profile="", errors_queue=None):
    """
    update a single ecs service and wait until it reaches a steady state
    :param env_name: key in ENVS; its "aws_region" entry is used as the ecs client region
    :param ms_name: service name to update
    :param session_profile: optional aws profile name for the boto3 session; empty means the default session
    :param errors_queue: when given (threaded run), errors are put on it instead of being passed
        to exit_on_error_and_write_summary
    :return: None
    """
    try:
        if session_profile:
            my_session = boto3.session.Session(profile_name=session_profile)
        else:
            my_session = boto3.session.Session()
        ecs_region = ENVS[env_name].get("aws_region")
        client = my_session.client('ecs', region_name=ecs_region)
        task_name, cluster_name = get_task_and_cluster_names(ms_name, env_name)
        current_service_message, current_message_id = get_describe_service_message(client, ms_name, cluster_name, env_name)
        # remember the id of an already-present steady-state event so the wait loop can ignore it
        old_id = ""
        if ECS_REACH_STEADY_STATE in current_service_message:
            logger.info_green(f"service: {ms_name} in cluster: {cluster_name} has been in steady state")
            old_id = current_message_id
        update_service(client, ms_name, task_name, cluster_name)
        wait_for_service_to_update(client, ms_name, cluster_name, old_id, env_name, errors_queue=errors_queue)
    except SystemExit as e:
        # NOTE(review): update_service reports failures through exit_on_error_and_write_summary, which
        # presumably raises SystemExit - confirm. inside a worker thread a SystemExit is silently ignored,
        # so without this the failure would never reach errors_queue and the run would look successful
        if errors_queue is None:
            raise
        errors_queue.put(f"error in update_one_ecs_service for service: {ms_name} - exited with code {e.code}")
    except Exception as e:
        error_message = f"error in update_one_ecs_service for service: {ms_name} - {e}"
        # None is the "no queue" sentinel; do not rely on the truthiness of the queue object
        if errors_queue is not None:
            errors_queue.put(error_message)
        else:
            exit_on_error_and_write_summary(error_message)


def main(session_profile=""):
    """
    entry point: update a single ecs service, or several services in parallel threads.
    env vars: ENV_NAME (read with get_required_env_var, must be a key of ENVS), SVC_TYPE (optional)
    and SVC_NAME (read with get_required_env_var when a single service is updated).
    when SVC_TYPE is 'core-ecs' the services are the entries listed under DEFAULT_ENV_FILES_PATH;
    when that listing is empty, or SVC_TYPE has any other value, only SVC_NAME is updated
    :param session_profile: optional aws profile name, forwarded to update_one_ecs_service
    :return: None. any Exception is reported through exit_on_error_and_write_summary
    """
    try:
        env_name = get_required_env_var("ENV_NAME")
        if env_name not in ENVS:
            exit_on_error_and_write_summary(f"env {env_name} not found in ENVS")
        svc_list = []
        is_core_ecs = os.getenv("SVC_TYPE", "") == 'core-ecs'
        if is_core_ecs:
            if not os.path.exists(DEFAULT_ENV_FILES_PATH):
                exit_on_error_and_write_summary(f"env files path {DEFAULT_ENV_FILES_PATH} not found")
            svc_list = os.listdir(DEFAULT_ENV_FILES_PATH)
        if not svc_list:
            ms_name = get_required_env_var("SVC_NAME")
            update_one_ecs_service(env_name, ms_name, session_profile)
        else:
            logger.info("will update all services in threads")
            errors_queue = queue.Queue()
            workers = [
                Thread(target=update_one_ecs_service, args=(env_name, ms_name, session_profile, errors_queue))
                for ms_name in svc_list
            ]
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()
            # every worker has finished, so the queue can be drained without blocking;
            # report all collected errors, not only the first one
            errors = []
            while not errors_queue.empty():
                errors.append(str(errors_queue.get_nowait()))
            if errors:
                all_errors = "; ".join(errors)
                raise RuntimeError(f"errors queue is not empty: {all_errors}")
        logger.info_green("all ecs services have been updated successfully")
    except Exception as e:
        exit_on_error_and_write_summary(f"error in update_ecs_service.main: {e}")


# run the update only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()
