import contextlib
import time
from functools import wraps

from aws_lambda_tools.common.response import Response
from aws_xray_sdk.core import xray_recorder
from merchant_wallet_libs.logging.globals import get_current_ip_address

from conio_sdk.common import exceptions
from conio_sdk.common.exceptions_mappings import get_response_for_exception
from conio_sdk.logging.factory import LOGGING_FACTORY, EVENTS_FACTORY
from etc import settings


def manage_errors(fun):
    """Decorate ``fun`` so that raised exceptions become HTTP responses.

    * ``exceptions.BFFException`` (application errors) is logged as a warning,
      with traceback, and converted via ``get_response_for_exception``.
    * Any other ``Exception`` is logged with traceback and converted into a
      ``Response`` with ``status_code=500``.

    The positional and keyword arguments of the failing call are included in
    the log record.

    :param fun: the callable to protect.
    :return: a wrapper with the same call signature as ``fun``. Its metadata
        (``__name__``, ``__doc__``, ``__wrapped__``, ...) is copied from
        ``fun``, so decorators stacked on top that rely on it (for example,
        naming a trace subsegment after ``__name__``) still see the real
        function instead of ``_inner``.
    """
    @wraps(fun)
    def _inner(*a, **kw):
        try:
            return fun(*a, **kw)
        except exceptions.BFFException as e:
            # NOTE(review): call arguments are logged verbatim; confirm they
            # never carry secrets/PII.
            LOGGING_FACTORY.request.warning(
                'Application error running %s(%s, %s): %s', fun, a, kw, e,
                exc_info=True)
            return get_response_for_exception(e)
        except Exception as e:
            # Catch-all boundary: an unexpected error must not escape the
            # handler. Log it with traceback and answer with a generic 500.
            LOGGING_FACTORY.request.exception('Error running %s(%s, %s): %s', fun, a, kw, e)
            return Response(status_code=500)
    return _inner


def log_this(fun):
    """Decorate a handler to log its input/output and emit a timing event.

    The incoming ``event`` is logged before the call and the handler's result
    after it. If the handler raises, the error is logged with traceback and
    re-raised unchanged; no timing event is emitted in that case.

    On success, a timing event is sent through ``EVENTS_FACTORY.timing`` with:
    the handler's duration in seconds, ``res['statusCode']``,
    ``event['path']``, and the current client IP. Emitting that event is
    best-effort. A failure there (for example, an event without a ``'path'``
    key) is logged with traceback, and the handler's result is still returned.

    :param fun: handler called as ``fun(event, *args, **kwargs)``.
    :return: a wrapper with the same signature that returns ``fun``'s result.
    """
    @wraps(fun)
    def wrapper(event, *args, **kwargs):
        LOGGING_FACTORY.request.info('Invoking %s', event)
        # perf_counter is monotonic: unlike time.time() it cannot go backwards
        # or jump when the system clock is adjusted, so durations stay valid.
        start = time.perf_counter()
        try:
            res = fun(event, *args, **kwargs)
        except Exception:
            LOGGING_FACTORY.request.exception('Error processing %s', event)
            raise
        delta = time.perf_counter() - start
        LOGGING_FACTORY.request.info('Response %s', res)
        try:
            EVENTS_FACTORY.timing.log(
                {
                    'duration': delta,
                    'status_code': res['statusCode'],
                    'url': event['path'],
                    'client_ip': get_current_ip_address()
                }
            )
        except Exception:
            # Telemetry must not turn an already-successful invocation into an
            # error: record the problem and still hand back the real result.
            LOGGING_FACTORY.request.exception('Failed to emit timing event for %s', event)
        return res
    return wrapper


if not settings.AWS_TRACING:
    # Tracing disabled: expose the same decorator / context-manager API as the
    # traced variants, but make both no-ops so call sites need no branching.
    def xray_this(fun):
        """Return ``fun`` untouched; tracing is switched off."""
        return fun

    @contextlib.contextmanager
    def xray_ctx(subsegment_name: str):
        """No-op context manager with the same signature as the traced one."""
        yield

else:
    def xray_this(fun):
        """Decorate ``fun`` so each call runs inside an X-Ray subsegment.

        The subsegment is named after ``fun.__name__``.
        """
        @wraps(fun)
        def traced(*args, **kwargs):
            with xray_recorder.in_subsegment(fun.__name__):
                return fun(*args, **kwargs)
        return traced

    def xray_ctx(subsegment_name: str):
        """Return ``xray_recorder.in_subsegment(subsegment_name)``."""
        return xray_recorder.in_subsegment(subsegment_name)
