"""**ppg_detector.py**: module with real-time feature detection
algorithms for Byteflies PPG nodes."""

import logging

import numpy as np
from scipy import signal
from configobj import ConfigObj

from shared.peakfind import peakfind
from feature_detection import Rslt


class PpgDetector():
    """Real-time PPG systole, diastole and heart-rate detector.

    Samples are pushed one at a time through ``append``. Every time the
    internal buffer holds ``win_size`` samples, the window is analyzed for a
    systole/diastole pair, detections are put on the ``results`` queue, and
    the buffer is shifted by ``win_shift`` samples.
    """

    def __init__(self, results, config_path='ppg/ppg.config'):
        """PpgDetector constructor.

        Arguments:
            results (Queue): queue of results to send to display server
            config_path (str): path of the ConfigObj file that provides
                ``fs``, ``win_size`` and ``win_shift`` (both in seconds) and
                the detection thresholds (``cycle_amp``, ``feat_amp``,
                ``syst_to_dias_low``, ``syst_to_dias_high``). Defaults to the
                historical relative path ``ppg/ppg.config``.
        """
        self.results = results

        self.config = ConfigObj(config_path)

        self.fs = int(self.config['fs'])
        # Window length and hop are configured in seconds; convert to samples.
        self.win_size = int(self.fs * float(self.config['win_size']))
        self.win_shift = int(self.fs * float(self.config['win_shift']))

        self.data = list()   # buffered raw samples of the current window
        self.index = 0       # absolute sample index of self.data[0]
        self.hr_counter = 0  # windows analyzed since the last HR update
        # dias_hist: the last (up to) two reported diastole indices.
        # inst_hr_hist: [diastole index, instantaneous HR] history.
        self.dias_hist, self.inst_hr_hist = list(), list()

    def append(self, sample):
        """Append a new sample to the PPG feature detector data queue.

        When the data queue reaches the win_size, the data queue is analyzed
        and the queue is shifted by win_shift. The analysis includes systole,
        diastole and HR detection.

        Arguments:
            sample (scalar): new data point to add to the data queue
        """
        self.data.append(sample)

        if len(self.data) == self.win_size:
            self.hr_counter += 1
            # The window is always analyzed; the heart rate is only refreshed
            # when a new beat was found and at least 4 windows have elapsed
            # since the previous HR update.
            if self._analyzeChunk() and self.hr_counter >= 4:
                self.hr_counter = 0
                self._calc_rate()

            del self.data[0:self.win_shift]
            self.index += self.win_shift

    def _analyzeChunk(self):
        """Analyze the current window of PPG data.

        The window is z-normalized, systole and diastole are detected, and a
        non-redundant detection is sent to the results queue.

        Returns:
            result (bool): True if a new feature was detected
        """
        # Normalize chunk (zero mean, unit variance); an explicit float array
        # avoids relying on list/numpy-scalar broadcasting.
        chunk = np.asarray(self.data, dtype=float)
        chunk_norm = chunk - np.mean(chunk)
        std_chunk = np.std(chunk_norm)
        if std_chunk != 0:  # a flat window would otherwise divide by zero
            chunk_norm /= std_chunk

        syst, dias, _ = self._ppg_feature(chunk_norm)

        # Overlapping windows re-detect the same beat; only report a diastole
        # that lies more than 0.1 s after the last reported one.
        if syst is not None and not self._check_redundant(dias[0], 0.1):
            self.dias_hist.append(dias[0])
            if len(self.dias_hist) > 2:
                self.dias_hist.pop(0)
            self.results.put([Rslt.SYST, syst])
            self.results.put([Rslt.DIAS, dias])
            return True
        return False

    def _ppg_feature(self, chunk):
        """Filter the input PPG signal and detect systole and diastole
        features.

        To ensure high true positive detection rates, feed data chunks that
        share sufficient window overlap. The trade-off is a slight increase
        in false positive detection rate.

        Arguments:
            chunk (ndarray): PPG signal

        Returns:
            (tuple): tuple containing:
                - syst (list or None): absolute index and amplitude of systole
                - dias (list or None): absolute index and amplitude of
                                       diastole
                - chunk_HP (ndarray): input PPG signal after applying the
                                      highpass filter
        """
        # TODO check filters!
        HP, LP = 1/(self.fs/2.0), 1.5/(self.fs/2.0)
        # Zero-phase highpass filter PPG signal (baseline drift)
        b, a = signal.butter(4, HP, 'high')
        chunk_HP = signal.filtfilt(b, a, chunk)
        # Zero-phase lowpass filter PPG signal and invert (identify signal
        # cycles)
        b, a = signal.butter(4, LP, 'low')
        chunk_LP = -signal.filtfilt(b, a, chunk_HP)

        # Primary feature detection. _systdias_RT returns None when an escape
        # rule fires; check for that explicitly instead of catching TypeError,
        # which would also hide genuine errors inside the detector.
        features = self._systdias_RT(chunk_HP, chunk_LP)
        if features is None:
            return None, None, chunk_HP

        syst, dias = features
        # Convert chunk-relative indices to absolute sample indices
        syst[0] += self.index
        dias[0] += self.index
        return syst, dias, chunk_HP

    def _systdias_RT(self, x_sample, cycle_sample):
        """Detect one systole/diastole pair in a PPG chunk sent by a
        real-time processor.

        This is a rule-based algorithm that only requires two derivations of
        the original signal (at low computational cost):
        1) highpass filtered: baseline drift correction
        2) lowpass filtered and inverted: delineation of PPG cycles (peaks
           ~ diastole, troughs ~ systole)
        If any escape rule fails, the algorithm returns None.

        WARNING!: upsampled data with "staircase" morphology can lead to
        unexpected results; downsample the input data first if necessary.

        Arguments:
            x_sample (ndarray): highpass filtered PPG signal
            cycle_sample (ndarray): inverted lowpass filtered PPG signal
                                    (must support unary negation)

        Returns:
            (tuple or None): tuple containing:

                - syst (list): chunk-relative index and amplitude of systole
                - dias (list): chunk-relative index and amplitude of diastole

            or None when no valid PPG cycle is found.
        """
        # Sample counts derived from the sampling frequency
        edge_lag = int(self.fs / 25)  # Window edge artifact exclusion
        nudge = int(self.fs / 20)  # Nudge away from or to next feature
        around = int(self.fs / 5)  # Search area around feature candidate

        # Find candidate PPG cycles: identify peak in inverted cycle_sample
        # that delineate the start of a potential PPG cycle (~ diastole)
        max_i1 = peakfind(cycle_sample, 0, i=edge_lag, fail_flag=True)
        # ESCAPE RULE #1: check for presence of initiating PPG cycle
        if max_i1 == -1:
            logging.debug('PPG cycle start boundary missing')
            return None

        # Find 2nd peak (~ diastole) to the right side of the previous
        max_i2 = peakfind(cycle_sample, 0, i=max_i1+nudge, fail_flag=True)
        # Find systole surrounded by max_i1 and max_i2
        min_i = peakfind(-cycle_sample, 0, i=max_i1, max_depth=max_i2,
                         fail_flag=True)
        # ESCAPE RULE #2: check cycle boundaries and amplitude of the
        # cycle_sample signal. max_i2 must lie after max_i1, otherwise the
        # systole search slice below would be empty.
        if max_i2 == -1 or min_i == -1 or max_i2 <= max_i1:
            logging.debug('PPG cycle misconfigured')
            return None
        if (cycle_sample[max_i1] - cycle_sample[min_i] <
                float(self.config['cycle_amp'])):
            logging.debug('PPG cycle amplitude too low')
            return None

        # Identify index and amplitude of candidate systole in cycle
        syst_i = np.argmax(x_sample[max_i1:max_i2]) + max_i1
        syst_amp = x_sample[syst_i]
        # Find candidate diastole around maximum in cycle. Clamp the window
        # start at 0: a negative start would wrap to the end of the array and
        # yield an empty (argmin -> ValueError) or wrong slice.
        dias_lo = max(max_i2 - around, 0)
        dias_i = np.argmin(x_sample[dias_lo:max_i2 + around]) + dias_lo
        dias_amp = x_sample[dias_i]
        # ESCAPE RULE #3: check candidate systolic peak prominence
        if syst_amp - dias_amp < float(self.config['feat_amp']):
            logging.debug('PPG systole-to-diastole prominence too low')
            return None

        # ESCAPE RULE #4: cycle time is clearly wrong
        if not (float(self.config['syst_to_dias_low'])*self.fs <=
                dias_i - syst_i <=
                (float(self.config['syst_to_dias_high'])*self.fs)):
            logging.debug('systole-to-diastole timing is out-of-range')
            return None
        # ESCAPE RULE #5: positional relation of systole and diastole is wrong
        if syst_amp <= dias_amp or dias_amp >= 0 or syst_amp <= 0:
            logging.debug('relation of systole and diastole amplitude do not '
                          'match')
            return None

        syst = [syst_i, syst_amp]
        dias = [dias_i, dias_amp]

        return syst, dias

    def _check_redundant(self, feat, tolerance):
        """Check if a new diastole detection is redundant.

        The feature detection algorithms have redundancy built-in by design
        (overlapping windows) to improve feature detection accuracy. This
        function removes the redundancy by comparing the new index with the
        most recently reported diastole index in ``self.dias_hist``.

        Arguments:
            feat (scalar): index of the new feature (in samples)
            tolerance (scalar): minimum separation, in seconds, that the new
                                feature must exceed to be considered new

        Returns:
            redundant (bool): True if new feature is redundant, i.e. it lies
                              at most ``tolerance`` seconds after (or anywhere
                              before) the last reported diastole
        """
        if len(self.dias_hist) > 0:
            master_diff = feat - self.dias_hist[-1]

            # Convert tolerance from seconds to samples
            tol = tolerance * self.fs
            if master_diff <= tol:
                return True
        return False

    def _calc_rate(self):
        """Calculate the instantaneous and average heart rate.

        The instantaneous heart rate comes from the interval between the last
        two reported diastoles. The average heart rate is the mean of the
        instantaneous rates within the last 10 s. Each time a new heart rate
        is calculated both values are sent to the results queue.
        """
        # Compute instantaneous HR
        inst_HR, avg_HR = None, None
        if len(self.dias_hist) == 2:
            IBI = self.dias_hist[-1] - self.dias_hist[-2]  # Interbeat interval
            inst_HR = (self.fs / IBI) * 60  # Convert to BPM
            self.inst_hr_hist.append([self.dias_hist[-1], inst_HR])

        # Remove instantaneous HR entries older than 10 s
        if len(self.inst_hr_hist) > 0:
            while (self.inst_hr_hist[-1][0] - self.inst_hr_hist[0][0] >
                   10 * self.fs):
                self.inst_hr_hist.pop(0)

        if len(self.inst_hr_hist) > 0:
            avg_HR = np.mean([x[1] for x in self.inst_hr_hist])

        if inst_HR is not None:
            self.results.put([Rslt.INST_HR, [self.dias_hist[-1], inst_HR]])
            self.results.put([Rslt.AVG_HR, [self.dias_hist[-1], avg_HR]])
