import io
import os
import re
import sys

import boto3
import yaml


def invert_dict_type(d):
    """Index microservice entries by their ``ms_name``.

    ``d`` maps a group key to either a single microservice entry or a list of
    entries; each entry is a dict with at least ``ms_name``, ``type`` and
    ``stage`` keys (any other keys are ignored). The group key is discarded.

    Returns a new dict ``{ms_name: {"type": ..., "stage": ...}}``. If the same
    ``ms_name`` appears more than once, the last occurrence wins. A falsy
    ``d`` (``None`` or ``{}``) yields an empty dict.
    """
    inverted = {}
    # An empty repos.yaml parses to None; treat it like an empty catalogue.
    if not d:
        return inverted
    for value in d.values():
        # Normalise the single-entry form to a one-element list so both
        # shapes share the same code path.
        entries = value if isinstance(value, list) else [value]
        for ms_dict in entries:
            inverted[ms_dict['ms_name']] = {"type": ms_dict['type'], "stage": ms_dict['stage']}
    return inverted


def get_repos():
    """Load the microservice catalogue from ``repos.yaml`` in the working directory.

    Returns the parsed YAML document. Returns an empty dict when the YAML
    cannot be parsed (the parser error is printed) or when the file is empty.
    A missing file still raises ``FileNotFoundError``.
    """
    repos = {}
    with open("repos.yaml", "r", encoding="utf-8") as stream:
        try:
            repos = yaml.safe_load(stream)
        except yaml.YAMLError as e:
            print(e)
    # yaml.safe_load returns None for an empty document; return {} instead so
    # callers can iterate the result without a None check.
    return repos if repos is not None else {}


def get_pipeline_version(ms_name, env_name):
    """Return the version of the MSK Connect custom plugin backing a pipeline.

    Lists connectors whose name starts with ``{env_name}-{ms_name}`` using a
    ``kafkaconnect`` client in the region held by the module-level
    ``msk_region``. The plugin name is the second-to-last ``/``-separated
    segment of the connector's custom plugin ARN. The version is the part of
    that name after ``-{ms_name}-``, with ``-`` replaced by ``.``.

    Returns the version string, or None (after printing a diagnostic) when the
    prefix does not match exactly one connector, when that connector does not
    have exactly one plugin, or on any error.
    """
    # Pre-bind so the except block can always print it; without this, a
    # failure in boto3.client()/list_connectors() raised UnboundLocalError
    # from inside the handler.
    response = None
    try:
        msk = boto3.client('kafkaconnect', region_name=msk_region)
        response = msk.list_connectors(
            connectorNamePrefix=f'{env_name}-{ms_name}'
        )
        connectors = response["connectors"]
        if len(connectors) != 1:
            print(f"error- len of connectors with the prefix {env_name}-{ms_name} not equal to 1")
            return None
        connector = connectors[0]
        if len(connector['plugins']) != 1:
            print(f"error- len of plugins of connector {connector['connectorName']} not equal 1")
            return None
        plugin_arn = connector['plugins'][0]['customPlugin']['customPluginArn']
        plugin_name = plugin_arn.split("/")[-2]
        plugin_version = plugin_name.split(f"-{ms_name}-")[-1].replace("-", ".")
        return plugin_version
    except Exception as e:
        print(f"error in get_pipeline_version- {e}")
        print(response)
        return None


def get_lambda_version(ms_name, env_name):
    """Return the version of a lambda's deployment zip in the env files bucket.

    Scans every object under the ``zips`` prefix of bucket
    ``p81-{env_name}-env-files``. For each key containing ``ms_name`` it takes
    the text after the first ``ms_name`` occurrence up to ``.zip``, drops one
    leading ``-``, and returns it if it matches the version pattern.

    Returns the first matching version, or None when nothing matches or on
    error (the error is printed).
    """
    try:
        s3_client = boto3.client('s3')
        # A single list_objects_v2 call returns at most 1000 keys; paginate so
        # zips that fall beyond the first page are still found.
        paginator = s3_client.get_paginator('list_objects_v2')
        for page in paginator.paginate(Bucket=f"p81-{env_name}-env-files", Prefix="zips"):
            # Pages with no matching objects carry no 'Contents' key.
            for obj in page.get('Contents', []):
                file_name = obj['Key']
                if ms_name not in file_name:
                    continue
                version = file_name.split(ms_name)[1].split(".zip")[0]
                if version.startswith("-"):
                    version = version[1:]
                # NOTE(review): the dots in this pattern are unescaped regex
                # wildcards, so separators other than '.' are also accepted.
                if re.search("[0-9]+.[0-9]+.[0-9]+.*", version):
                    return version
        return None
    except Exception as e:
        print(f"error in get_lambda_version- {e}")
        return None


def get_s3_version_file(ms_name, env_name):
    """Return the frontend version recorded in a static-content bucket.

    Downloads object ``VERSION.txt`` from bucket
    ``p81-{env_name}-fe-static-content-{ms_name}``. The S3 client is created in
    the region held by the module-level ``ecr_region``. The first line of the
    file is printed and returned if it matches the version pattern.

    Returns the version string, or None when the first line does not match or
    on any error (the error is printed).
    """
    try:
        s3 = boto3.client('s3', region_name=ecr_region)
        # Download into memory rather than a VERSION.txt in the working
        # directory. This avoids shelling out to touch/rm, leaking file
        # handles, and depending on the write handle being flushed before
        # the read.
        buffer = io.BytesIO()
        s3.download_fileobj(f'p81-{env_name}-fe-static-content-{ms_name}', 'VERSION.txt', buffer)
        # splitlines() also handles '\r\n', matching text-mode newline
        # translation. An empty file yields an empty version string.
        lines = buffer.getvalue().decode('utf-8').splitlines()
        version = lines[0] if lines else ''
        print(f'frontend {env_name} version for {ms_name} is {version}')
        if re.search("[0-9]+.[0-9]+.[0-9]+.*", version):
            return version
        return None
    except Exception as e:
        print(f"error in get_s3_version_file for ms {ms_name}- {e}")
        return None


def get_ecr_version(ms_name, env_name):
    """Find the version tag of the ECR image currently tagged ``env_name``.

    Describes the image tagged ``env_name`` in repository
    ``{ecr_prefix}{ms_name}`` (module-level ``ecr_prefix`` and ``ecr_region``).
    Returns the first of its tags that matches the ``v<major>.<minor>.<patch>``
    style pattern, or None if none match or any error occurs (the error is
    printed).
    """
    try:
        ecr = boto3.client('ecr', region_name=ecr_region)
        image_details = ecr.describe_images(
            repositoryName=f'{ecr_prefix}{ms_name}',
            imageIds=[{'imageTag': env_name}],
        )['imageDetails']
        candidate_tags = image_details[0]["imageTags"]
        version_tag = next(
            (tag for tag in candidate_tags
             if re.search(".*v[0-9]+.[0-9]+.[0-9]+.*", tag)),
            None,
        )
        if version_tag is None:
            return None
        print(f'ecr {env_name} version for {ms_name} is {version_tag}')
        return version_tag
    except Exception as err:
        print(f"error in get_ecr_version- {err}")
        return None


def get_env_version(ms_name, env_name):
    """Look up the deployed version of ``ms_name`` in environment ``env_name``.

    Uses the module-level ``invert_repos`` to find the microservice's type,
    then delegates to the matching fetcher, but only when the module-level
    ``types`` filter is ``'all'`` or equals that type. Returns the fetched
    version, or None when the service is unknown, its type is unsupported
    (a message is printed), or it is excluded by the ``types`` filter.
    """
    ms_info = invert_repos.get(ms_name)
    if not ms_info:
        print(f"in get_env_version ms_name: {ms_name} not in repos")
        return None

    ms_type = ms_info['type']
    fetchers = {
        'backend': get_ecr_version,
        'core-ecs': get_ecr_version,
        'frontend': get_s3_version_file,
        'lambda': get_lambda_version,
        'pipeline': get_pipeline_version,
    }
    fetch = fetchers.get(ms_type)
    if fetch is None:
        print(f"in get_env_version ms_name: {ms_name} type : {ms_type}")
        print(f"type not in ms to deploy")
        return None

    # Only query services selected by the `types` filter.
    if types != 'all' and ms_type != types:
        return None
    return fetch(ms_name, env_name)


if __name__ == '__main__':
    # These names are read as module-level globals by the get_*_version
    # helpers, so they must stay assigned at module scope.
    msk_region = os.getenv('MSK_REGION')
    ecr_region = os.getenv('ECR_REGION')
    ecr_prefix = os.getenv('DOCKER_PROJECT')
    # Positional args: env_to_get=<env> types=<ms type or "all">
    env_name = sys.argv[1].replace("env_to_get=", "", 1)
    types = sys.argv[2].replace("types=", "", 1)
    repos = get_repos()
    invert_repos = invert_dict_type(repos)
    final_versions = {}
    error_versions = []
    for ms_name, ms_info in invert_repos.items():
        # Skip services the `types` filter excludes. get_env_version would
        # return None for them, and they would be misreported as failures.
        if types != 'all' and ms_info['type'] != types:
            continue
        env_version = get_env_version(ms_name, env_name)
        if env_version:
            final_versions[ms_name] = env_version
        else:
            error_versions.append(ms_name)
    # Append to the GitHub Actions env file directly instead of via a shell
    # `echo`, so the dict repr is never subject to shell interpretation.
    github_env = os.getenv('GITHUB_ENV')
    if github_env:
        with open(github_env, 'a', encoding='utf-8') as env_file:
            env_file.write(f"INPUTS={final_versions}\n")
    else:
        print("GITHUB_ENV is not set; INPUTS was not exported")
    print("######################################")
    print(f"this is the versions I got for env {env_name}")
    for ms in final_versions:
        print(f"{ms}: {final_versions[ms]}")
    print("######################################")
    print(f"this is the versions I DID NOT get for env {env_name}")
    for ms in error_versions:
        print(f"ERROR- {ms}")
