# -*- coding: utf-8 -*-
from __future__ import print_function
import sys
from functools import wraps
import re
from datetime import datetime
from textwrap import dedent

from datacloud.dev_utils.id_value.id_value_lib import count_md5

from datacloud.input_pipeline.input_checker.xprod_csv_validator import (
    strip_val, check_one_of_required_fields, check_email, check_phone
)
from datacloud.input_pipeline.input_checker.constants import (
    allowed_headers, id_fields, DEFAULT_PHONE_LENGTH
)
from datacloud.dev_utils.id_value.id_value_lib import normalize_email, normalize_phone
from datacloud.input_pipeline.input_pipeline.helpers import check_target_prefix
from datacloud.dev_utils.validators.csv_validator import RecordError
from datacloud.input_pipeline.normalizer.constants import MULTIPLE_VALUES_DELIMITERS


def query_yes_no(question, default="yes"):
    """Prompt the user with a yes/no question on stdout and read the reply.

    "question" is the text shown to the user; a hint such as " [Y/n] " is
        appended to it.
    "default" is the answer assumed when the user just presses <Enter>.
        Allowed values are "yes" (the default), "no" and None (no default,
        the user has to type an answer).

    Returns True for a "yes" answer and False for a "no" answer.
    Raises ValueError when "default" is not one of the allowed values.
    """
    answers = {"yes": True, "y": True, "ye": True,
               "no": False, "n": False}

    # The capital letter in the hint marks the answer used on empty input.
    if default is None:
        hint = " [y/n] "
    elif default == "yes":
        hint = " [Y/n] "
    elif default == "no":
        hint = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    full_prompt = question + hint
    while True:
        sys.stdout.write(full_prompt)
        reply = raw_input().lower()
        if reply == '' and default is not None:
            return answers[default]
        if reply in answers:
            return answers[reply]
        # Unrecognized input: explain what is expected and ask again.
        sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")


class XProdCsvNormalizer(object):
    """
    Streaming normalizer for row-oriented (CSV-like) data.

    Normalizers are registered with the `add_*_normalizer` methods and applied
    to every data row by `inormalize` / `normalize`. Rows rejected by
    `check_one_of_required_fields(id_fields)`, exact duplicate rows and (with
    `hard_kill`) rows repeating an already seen (external_id, retro_date) pair
    are dropped and counted; the counters are exposed as read-only properties.
    """

    def __init__(
            self,
            field_names=None,
            normalize_if_exists=False,
            hard_kill=False,
            max_dups_to_print=1,
            max_empty_ids_to_print=1):
        """
        Instantiate an `XProdCsvNormalizer`.
        Arguments
        ---------
        `field_names` - expected field names as a sequence of strings; may be
        omitted when the data carries a header row
        `normalize_if_exists` - if true, value normalizers registered for
        fields missing from the field names are dropped instead of failing an
        assertion
        `hard_kill` - if true, rows repeating an already seen
        (external_id, retro_date) pair are dropped
        `max_dups_to_print` - how many duplicate rows are printed
        `max_empty_ids_to_print` - how many rows with empty ids are printed
        """

        self._field_names = field_names or []
        # (field_name, callable) pairs, applied in registration order
        self._value_normalizers = []
        self._target_normalizers = []
        self._record_normalizers = []
        # Rows already accepted, stored as tuples, for duplicate detection
        self._rows_set = set()
        # (external_id, retro_date) pairs of the rows already yielded
        self._extid_retrodate_set = set()
        self._duplicates_counter = 0
        self._empty_ids_counter = 0
        self._dup_contracts_counter = 0
        self._normalize_if_exists = normalize_if_exists
        self._hard_kill = hard_kill
        # Both "max to print" values are decremented as rows get printed
        self._max_dups_to_print = max_dups_to_print
        self._max_empty_ids_to_print = max_empty_ids_to_print
        self._target_fields = self._select_target_fields()

    def _select_target_fields(self):
        """
        Return the field names accepted by `check_target_prefix` as a list.
        """
        # A real list is required: the fields are iterated once per data row,
        # and a lazy `filter` object (Python 3) would be exhausted after the
        # first row, silently disabling target normalizers for the rest.
        return [name for name in self._field_names if check_target_prefix(name)]

    def add_value_normalizer(self, field_name, value_normalizer):
        """
        Add a value normalizer for the specified field.
        Arguments
        ---------
        `field_name` - the name of the field to attach the value normalizer
        function to
        `value_normalizer` - a function that accepts a single argument (a value) and
        returns normalized value
        """

        assert callable(value_normalizer), 'value normalizer must be a callable function'

        self._value_normalizers.append((field_name, value_normalizer))

    def add_record_normalizer(self, record_normalizer):
        """
        Add a record normalizer applied to the whole row.
        Arguments
        ---------
        `record_normalizer` - a function that accepts a single argument (a record
        as a dict of field name => value) and returns normalized record
        """

        assert callable(record_normalizer), 'record normalizer must be a callable function'

        self._record_normalizers.append(record_normalizer)

    def add_target_normalizer(self, value_normalizer):
        """
        Add a value normalizer applied to every target field.
        Arguments
        ---------
        `value_normalizer` - a function that accepts a single argument (a value) and
        returns normalized value
        """

        assert callable(value_normalizer), 'value normalizer must be a callable function'

        self._target_normalizers.append(value_normalizer)

    def _check_value_normalizers_guard_conditions(self):
        """
        Checks if all value normalizers use only fields from self._field_names
        If self._normalize_if_exists => normalizers of unknown fields are
        deleted, otherwise an unknown field fails an assertion
        """

        value_normalizers = []
        for field_name, value_normalizer in self._value_normalizers:
            if self._normalize_if_exists:
                if field_name in self._field_names:
                    value_normalizers.append((field_name, value_normalizer))
            else:
                assert field_name in self._field_names, 'unexpected field name: %s' % field_name

        if self._normalize_if_exists:
            self._value_normalizers = value_normalizers

    def _apply_value_normalizers(self, r):
        """
        Apply every registered value normalizer to its field of row `r`.
        The row is modified in place and returned.
        """
        for field_name, value_normalizer in self._value_normalizers:
            fi = self._field_names.index(field_name)
            if fi < len(r):  # only apply normalize if there is a value
                r[fi] = value_normalizer(r[fi])
        return r

    def _apply_record_normalizers(self, r_dict):
        """
        Chain all record normalizers over `r_dict` and return the result as a
        list ordered like the field names; missing fields become ''.
        """
        for record_normalizer in self._record_normalizers:
            r_dict = record_normalizer(r_dict)

        return [r_dict.get(field, '') for field in self._field_names]

    def _apply_target_normalizers(self, r):
        """
        Apply every target normalizer to every target field of row `r`.
        The row is modified in place and returned.
        """
        for field_name in self._target_fields:
            # The position depends on the field only, so look it up once
            fi = self._field_names.index(field_name)
            if fi >= len(r):  # only apply normalize if there is a value
                continue
            for value_normalizer in self._target_normalizers:
                r[fi] = value_normalizer(r[fi])
        return r

    def _apply_length_normalizer(self, r):
        """
        Pad row `r` with empty strings up to the number of field names.
        A row longer than the field names fails an assertion.
        """
        assert len(r) <= len(self._field_names), 'record %s is longer than field names' % r

        if len(r) < len(self._field_names):
            # Add empty values to record up to field names length
            return r + ['', ] * (len(self._field_names) - len(r))
        return r

    def _apply_header_normalizer(self, r):
        """
        Console dialog changing headers of input CSV

        Headers that are all in `allowed_headers` are returned untouched.
        Otherwise the user may rename each header by typing an index into
        `allowed_headers` (1-based), a free-form name, or ENTER to keep it.
        """

        if all(h in allowed_headers for h in r):
            return r
        print(dedent("""\
            File headers are
            {0}
        """).format(r))
        if query_yes_no('Do you want to rename them?'):
            print('Allowed headers are:')
            for n, h in enumerate(allowed_headers, 1):
                print('{}.{}'.format(n, h))
            print(dedent("""
                Type number to rename header
                sampleDate => 1
                or you can type the whole name of it
                applicationNumber => external_id

                press ENTER to ignore header
            """))
            t = []
            for header in r:
                print('{} => '.format(header), end='')
                choice = raw_input()
                if choice:
                    try:
                        # Numeric input is re-asked until it is a valid index
                        while int(choice) not in range(1, len(allowed_headers) + 1):
                            print('Index out of range!')
                            choice = raw_input()
                        t.append(allowed_headers[int(choice) - 1])
                    except ValueError:
                        # Non-numeric input is taken as the new header name
                        t.append(str(choice))
                else:
                    t.append(header)
            print(dedent("""\
                New headers are
                {0}\
            """).format(t))
            return t
        return r

    def normalize(self, data, expect_header_row=True, ignore_lines=0, limit=0, summary_once=0):
        """
        Normalize `data` and return a list of normalized rows.
        Arguments
        ---------
        `data` - any source of row-oriented data, e.g., as provided by a
        `csv.reader`, or a list of lists of strings, or ...
        `expect_header_row` - does the data contain a header row (i.e., the
        first record is a list of field names)? Defaults to True.
        `ignore_lines` - ignore n lines (rows) at the beginning of the data
        `limit` - maximum number of rows to return, counting the header row
        yielded by `inormalize`; 0 means no limit
        `summary_once` - print a progress line every `summary_once` rows;
        0 disables the progress output
        """
        rows = list()
        rows_generator = self.inormalize(data, expect_header_row, ignore_lines)
        for i, r in enumerate(rows_generator):
            if summary_once and (i + 1) % summary_once == 0:
                print('{0} rows normalized'.format(i + 1))
            if not limit or i < limit:
                rows.append(r)
            else:
                return rows
        return rows

    def inormalize(self, data, expect_header_row=True, ignore_lines=0):
        """
        Normalize `data` and return a iterator over normalized rows.
        Arguments
        ---------
        `data` - any source of row-oriented data, e.g., as provided by a
        `csv.reader`, or a list of lists of strings, or ...
        `expect_header_row` - does the data contain a header row (i.e., the
        first record is a list of field names)? Defaults to True.
        `ignore_lines` - ignore n lines (rows) at the beginning of the data

        The (possibly renamed) header row is yielded first when it is
        expected. Data rows must contain the 'external_id' and 'retro_date'
        fields, otherwise a KeyError is raised.
        """
        assert expect_header_row or self._field_names, 'expected field_names or CSV headers'

        for i, r in enumerate(data):
            if expect_header_row and i == ignore_lines:
                # r is the header row
                h = self._apply_header_normalizer(r)
                self._field_names = tuple(h)
                self._target_fields = self._select_target_fields()
                self._check_value_normalizers_guard_conditions()
                yield h
            elif i >= ignore_lines:
                # r is a data row
                r = self._apply_length_normalizer(r)
                r = self._apply_value_normalizers(r)
                r = self._apply_target_normalizers(r)

                # NOTE(review): the checks below read `r_dict`, the dict built
                # before record normalizers ran; a record normalizer that
                # returns a new dict instead of mutating its argument is not
                # reflected in them - confirm this is intended.
                r_dict = self._as_dict(r)
                r = self._apply_record_normalizers(r_dict)
                try:
                    check_one_of_required_fields(id_fields)(r_dict)
                    row_key = tuple(r)
                    if row_key not in self._rows_set:
                        self._rows_set.add(row_key)

                        retro_date_ext_id = (r_dict['external_id'], r_dict['retro_date'])
                        # Without hard_kill a repeated pair is still yielded
                        if retro_date_ext_id not in self._extid_retrodate_set or not self._hard_kill:
                            self._extid_retrodate_set.add(retro_date_ext_id)
                            yield r
                        else:
                            self._dup_contracts_counter += 1
                    else:
                        if self._max_dups_to_print > 0:
                            print('Record is duplicate')
                            print(r)
                            self._max_dups_to_print -= 1

                        self._duplicates_counter += 1
                except RecordError:
                    if self._max_empty_ids_to_print > 0:
                        print('Record has empty ids')
                        print(r)
                        self._max_empty_ids_to_print -= 1

                    self._empty_ids_counter += 1

    def _as_dict(self, r):
        """
        Return row `r` as a dict of field name => value.
        """
        return {
            name: val
            for name, val in zip(self._field_names, r)
        }

    @property
    def duplicates_counter(self):
        """Number of rows dropped as exact duplicates of an earlier row."""
        return self._duplicates_counter

    @property
    def dup_contracts_counter(self):
        """Number of rows dropped for a repeated (external_id, retro_date)."""
        return self._dup_contracts_counter

    @property
    def empty_ids_counter(self):
        """Number of rows dropped because the id fields check failed."""
        return self._empty_ids_counter


# Useful decorators
def multiple_values_splitter(delimiter):
    """
    Build a decorator that applies a single-value function to every part of
    a multi-value string.

    `delimiter` is a regular expression passed to `re.split`.

    For a `str` argument the decorated function strips surrounding
    whitespace, removes every double quote, splits the rest by `delimiter`,
    calls the wrapped function on each part and joins the results with ','.
    A part for which the wrapped function returns None is kept unchanged.
    Any non-`str` argument is passed to the wrapped function as is.
    """
    def split_decorator(func):
        @wraps(func)
        def func_wrapper(v):
            if not isinstance(v, str):
                return func(v)

            # `replace` gives the same result as the Python 2-only
            # `translate(None, '"')` and also works on Python 3.
            unquoted = v.strip().replace('"', '')
            normalized_parts = []
            for part in re.split(delimiter, unquoted):
                # Call `func` exactly once per part
                result = func(part)
                normalized_parts.append(part if result is None else result)
            return ','.join(normalized_parts)
        return func_wrapper
    return split_decorator


# Normalizer functions
def normalize_empty_target():
    """
    Build a normalizer that substitutes '-1' for an empty target value.

    The returned callable is the inner function wrapped with `strip_val`;
    the inner function maps any falsy value to the string '-1' and returns
    every other value unchanged.
    """
    @strip_val
    def normalizer(v):
        return v or '-1'
    return normalizer


def normalize_multi_hash():
    """
    Build a normalizer that lower-cases every value of a multi-value field.

    The single-value function is wrapped first with
    `multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)` and then with
    `strip_val`.
    """
    def normalizer(v):
        return v.lower()

    # Explicit composition, innermost wrapper first
    per_value = multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)(normalizer)
    return strip_val(per_value)


def normalize_multi_phone():
    """
    Build a normalizer that runs `normalize_phone` over every value of a
    multi-value field.

    Falsy values become '' without calling `normalize_phone`. The
    single-value function is wrapped with
    `multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)` and `strip_val`.
    """
    @strip_val
    @multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)
    def normalizer(v):
        return normalize_phone(v) if v else ''
    return normalizer


def normalize_multi_email():
    """
    Build a normalizer that runs `normalize_email` over every value of a
    multi-value field.

    Falsy values become '' without calling `normalize_email`. The
    single-value function is wrapped with
    `multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)` and `strip_val`.
    """
    @strip_val
    @multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)
    def normalizer(v):
        return normalize_email(v) if v else ''
    return normalizer


def normalize_date(output_format, input_date_format, hard_kill=False):
    """
    Build a normalizer that re-formats date strings.

    Arguments
    ---------
    `output_format` - `strftime` format of the returned value
    `input_date_format` - `strptime` format the incoming value must match
    `hard_kill` - if true, a value that does not match `input_date_format`
    is replaced by ''; if false, such a value raises AssertionError

    The returned function maps a falsy value to '' and a date with a year
    before 1900 to '' (after printing a notice).
    """
    def normalizer(v):
        if not v:
            return ''

        try:
            dt = datetime.strptime(v, input_date_format)
        except ValueError:
            print(dedent("""
                {0} doesn\'t match given format {1}
                erasing it
            """).format(v, input_date_format))
            # Raised explicitly rather than via `assert`: asserts are removed
            # under `python -O`, which would silently erase the value even
            # though hard_kill was not requested.
            if not hard_kill:
                raise AssertionError('Can not erase. Should raise hard kill flag')
            return ''

        if dt.year < 1900:
            print('Strange date {}. Will drop it...'.format(v))
            return ''
        return dt.strftime(output_format)

    return normalizer


def normalize_phone_7():
    """
    Build a normalizer that blanks out phone values equal to '7'.

    Every value of a multi-value field that is exactly '7' becomes ''; other
    values are returned unchanged. The single-value function is wrapped with
    `multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)` and `strip_val`.
    """
    @strip_val
    @multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)
    def normalizer(v):
        return '' if v == '7' else v
    return normalizer


def normalize_email_dash():
    """
    Build a normalizer that blanks out email values made of dashes only.

    Every value of a multi-value field that consists solely of '-'
    characters (the empty string included) becomes ''; other values are
    returned unchanged. The single-value function is wrapped with
    `multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)` and `strip_val`.
    """
    @strip_val
    @multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)
    def normalizer(v):
        only_dashes = len(v) * '-'
        return '' if v == only_dashes else v
    return normalizer


def normalize_email_russian_no():
    """
    Build a normalizer that blanks out two known placeholder email values.

    Every value of a multi-value field equal to one of the two byte
    sequences below becomes ''; other values are returned unchanged. The
    single-value function is wrapped with
    `multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)` and `strip_val`.
    """
    # Presumably the Russian word for "no" in cp1251 and the same word after
    # a lossy re-encoding (three UTF-8 replacement characters) - TODO confirm
    placeholders = ('\xed\xe5\xf2', '\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd')

    @strip_val
    @multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)
    def normalizer(v):
        return '' if v in placeholders else v
    return normalizer


def hard_kill_bad_emails():
    """
    Build a normalizer that erases emails rejected by the `check_email`
    validator.

    The returned function gives back '' when `check_email()(v)` raises
    ValueError and the value itself otherwise.
    """
    def normalizer(v):
        try:
            validate = check_email()
            validate(v)
        except ValueError:
            return ''
        else:
            return v
    return normalizer


def hard_kill_bad_phones():
    """
    Build a normalizer that erases phones rejected by the `check_phone`
    validator created with `DEFAULT_PHONE_LENGTH`.

    The returned function gives back '' when the validator raises ValueError
    and the value itself otherwise.
    """
    def normalizer(v):
        try:
            validate = check_phone(DEFAULT_PHONE_LENGTH)
            validate(v)
        except ValueError:
            return ''
        else:
            return v
    return normalizer


def try_to_int():
    """
    Build a normalizer that rewrites numeric values as integer strings.

    The returned function gives back `str` of the integer for values that
    parse as an integer, the truncated integer for values that parse as a
    float ('42.9' -> '42'), and the value itself when it cannot be converted
    (including 'nan', 'inf' and floats too large to represent).
    """
    def normalizer(v):
        # Exact integer parse first: going through float would corrupt
        # integers above 2**53 (e.g. long numeric identifiers).
        try:
            return str(int(v))
        except (ValueError, TypeError, OverflowError):
            pass

        try:
            return str(int(float(v)))
        except (ValueError, OverflowError):
            # OverflowError: int() of an infinite float
            return v
    return normalizer


def normalize_gender():
    """
    Build a normalizer that upper-cases a gender value.

    The returned function calls `upper()` on its argument and returns the
    result.
    """
    def normalizer(value):
        upper_cased = value.upper()
        return upper_cased
    return normalizer


def normalize_md5():
    """
    Build a record normalizer that fills the id-value fields of a record.

    For a record (dict) with a truthy 'phone' the result of `count_md5` on
    it is stored under 'phone_id_value'; likewise 'email' feeds
    'email_id_value'. The record is modified in place and returned.
    """
    # (source key, destination key), processed in this order
    id_value_fields = (
        ('phone', 'phone_id_value'),
        ('email', 'email_id_value'),
    )

    def normalizer(r):
        for source_key, id_value_key in id_value_fields:
            if r.get(source_key):
                r[id_value_key] = count_md5(r[source_key])
        return r
    return normalizer


def normalize_multi_yuid():
    """
    Build a normalizer that strips whitespace from every value of a
    multi-value field.

    Falsy values become ''. The single-value function is wrapped with
    `multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)` and `strip_val`.
    """
    @strip_val
    @multiple_values_splitter(MULTIPLE_VALUES_DELIMITERS)
    def normalizer(v):
        return v.strip() if v else ''
    return normalizer
