# -*- coding: utf-8 -*-
from __future__ import print_function
import calendar
import yt.wrapper as yt
from datetime import datetime, date
from textwrap import dedent
from collections import namedtuple

from datacloud.dev_utils.id_value.id_value_lib import count_md5
from datacloud.input_pipeline.input_checker.constants import MULTIPLE_VALUES_DELIMITER
from datacloud.dev_utils.time.patterns import FMT_DATE, FMT_DATE_YM


def map_csv_to_table(headers):
    """Build a converter from a CSV row to a dict keyed by ``headers``.

    The returned callable pairs cells with headers positionally (via ``zip``,
    so surplus cells or surplus headers are silently ignored) and turns empty
    string cells into ``None``. With duplicate headers the last cell wins.
    """
    def map_row(r):
        return dict(
            (header, None if cell == '' else cell)
            for header, cell in zip(headers, r)
        )
    return map_row


def field_id(field_name, headers):
    """Return the index of ``field_name`` in ``headers``, or ``None`` if it is absent."""
    try:
        position = headers.index(field_name)
    except ValueError:
        position = None
    return position


def get_id_field_values(field_name, field_id, field_value_id, row):
    """Collect identifier records from a single CSV row.

    :param field_name: key under which each raw (unhashed) value is stored.
    :param field_id: index in ``row`` of the raw-value column, or ``None``.
    :param field_value_id: index in ``row`` of a column whose values are taken
        verbatim as ``id_value`` (presumably precomputed ids), or ``None``.
    :param row: sequence of cell strings.
    :returns: a list with ``{field_name: raw, 'id_value': count_md5(raw)}`` for
        every non-empty raw value, followed by ``{'id_value': value}`` for every
        non-empty verbatim value. Cells are split on
        ``MULTIPLE_VALUES_DELIMITER``; an unset or out-of-range index adds nothing.
    """
    def _pieces(index):
        # Non-empty parts of row[index]; empty when the index is unset or past the row end.
        if index is None or index >= len(row):
            return []
        return [piece for piece in row[index].split(MULTIPLE_VALUES_DELIMITER) if piece]

    raw_entries = [
        {field_name: piece, 'id_value': count_md5(piece)}
        for piece in _pieces(field_id)
    ]
    verbatim_entries = [{'id_value': piece} for piece in _pieces(field_value_id)]
    return raw_entries + verbatim_entries


def get_audience_path(partner_id, ticket_name, settings):
    """Return the audience location for ``partner_id``.

    If ``settings.AUDIENCE_CUSTOM_BASE_ROOT`` is truthy the path is
    ``<root>/<ticket_name>/<partner_id>``; otherwise ``partner_id`` is
    concatenated directly onto ``settings.DEFAULT_AUDIENCE_DIR``.
    """
    custom_root = settings.AUDIENCE_CUSTOM_BASE_ROOT
    if not custom_root:
        # Plain concatenation: DEFAULT_AUDIENCE_DIR presumably carries its own
        # trailing separator -- confirm against settings.
        return settings.DEFAULT_AUDIENCE_DIR + partner_id
    return '{0}/{1}/{2}'.format(custom_root, ticket_name, partner_id)


def make_ticket_field_name(ticket_name, ticket_suffix):
    """Return ``'pipeline_'`` followed by ``ticket_name`` and ``ticket_suffix``."""
    return 'pipeline_{name}{suffix}'.format(name=ticket_name, suffix=ticket_suffix)


# Describes one input column that may carry identifiers:
#   column_name -- key looked up in the input record,
#   id_type     -- id type written to the produced rows,
#   apply_md5   -- whether raw values are hashed with count_md5 first.
# The typename must equal the module attribute name: pickle serializes
# namedtuple classes by reference (<module>.<typename>), and these instances
# are stored on mapper objects that get serialized for remote execution.
IdValueColumn = namedtuple('IdValueColumn', ['column_name', 'id_type', 'apply_md5'])


class MakeInputMapper(object):
    """Mapper callable that explodes one input record into per-identifier rows.

    For each record it yields one row per identifier found in the columns
    listed in ``self.id_types``. Every row holds ``external_id``, a Unix
    ``timestamp`` computed from ``retro_date``, each target column cast to
    ``int``, plus ``id_type`` and ``id_value``.
    """

    def __init__(self, target_columns, date_format):
        """
        :param target_columns: names of columns copied to every output row as ``int``.
        :param date_format: ``strptime`` pattern used to parse ``rec['retro_date']``.
        """
        self.target_columns = target_columns
        self.date_format = date_format
        # Order matters: __call__ uses only the first truthy column per id_type,
        # so a raw value (hashed here) wins over a precomputed *_id_value column.
        self.id_types = (
            IdValueColumn(column_name='email', id_type='email_md5', apply_md5=True),
            IdValueColumn(column_name='email_id_value', id_type='email_md5', apply_md5=False),
            IdValueColumn(column_name='phone', id_type='phone_md5', apply_md5=True),
            IdValueColumn(column_name='phone_id_value', id_type='phone_md5', apply_md5=False),
            IdValueColumn(column_name='yuid', id_type='yuid', apply_md5=False),
        )

    def _split_val(self, val, apply_md5):
        """Yield non-empty parts of a multi-valued cell, md5-hashed when ``apply_md5``."""
        for idr in val.split(MULTIPLE_VALUES_DELIMITER):
            if idr:
                if apply_md5:
                    yield count_md5(idr)
                else:
                    yield idr

    def __call__(self, rec):
        """Yield one output row per identifier found in ``rec``."""
        dt = datetime.strptime(rec['retro_date'], self.date_format)
        base_rec = {
            'external_id': rec['external_id'],
            # strptime returns a naive datetime; utctimetuple + timegm treat it as UTC.
            'timestamp': calendar.timegm(dt.utctimetuple())
        }
        for target in self.target_columns:
            base_rec[target] = int(rec[target])

        processed_id_types = set()
        for column_name, id_type, apply_md5 in self.id_types:
            if rec.get(column_name) and id_type not in processed_id_types:
                for id_value in self._split_val(rec[column_name], apply_md5=apply_md5):
                    yield dict(id_type=id_type, id_value=id_value, **base_rec)
                processed_id_types.add(id_type)


class GlueExternalIdMapper(object):
    """Mapper callable that suffixes ``external_id`` with the record's retro date.

    The record is modified in place and yielded, with ``external_id`` becoming
    ``'<external_id>_<retro_date>'``.
    """

    def __init__(self, date_format=None):
        """
        :param date_format: when ``None``, ``rec['retro_date']`` is used verbatim;
            otherwise the date is rendered from ``rec['timestamp']`` (seconds since
            the epoch, converted as UTC) with this ``strftime`` pattern.
        """
        self.date_format = date_format

    def __call__(self, rec):
        """Append ``_<retro_date>`` to ``rec['external_id']`` and yield ``rec``."""
        if self.date_format is None:
            retro_date = rec['retro_date']
        else:
            ts = rec['timestamp']
            retro_date = datetime.utcfromtimestamp(ts).strftime(self.date_format)
        rec['external_id'] += '_{}'.format(retro_date)
        yield rec


@yt.with_context
class FeaturesJoinTargetReducer(object):
    """Join reducer that attaches target values to the feature rows of a key.

    Rows from table 0 supply the target columns; rows from any other table
    supply ``external_id`` and ``features``. A feature row is emitted only
    after at least one table-0 row has been seen for the same key.
    """

    def __init__(self, target_columns):
        # Names of target columns copied (cast to int) from table-0 rows.
        self.target_columns = target_columns

    def __call__(self, key, recs, context):
        targets = {}
        has_target = False
        for rec in recs:
            if context.table_index == 0:
                for target in self.target_columns:
                    targets[target] = int(rec[target])
                has_target = True
            elif has_target:
                # Build a fresh dict per output row. Yielding one shared dict that
                # is mutated afterwards makes every collected row alias the last one.
                row = dict(targets)
                row['external_id'] = rec['external_id']
                row['features'] = rec['features']
                yield row


@yt.with_context
class HistoryTableReducer(object):
    """Join reducer that pairs identifier rows with retro-date rows of a key.

    Rows from table 0 are collected as identifier rows. For every row from any
    other table, one output row is yielded per identifier row collected so far,
    carrying ``target``, ``retro_date`` and ``upper_bound_date``.
    """

    def __init__(self, partner, ticket, target_column=None):
        """
        :param partner: value stored in the ``partner`` column of every output row.
        :param ticket: value stored in the ``ticket`` column of every output row.
        :param target_column: column read (and cast to ``int``) from non-table-0
            rows as the target; when it is ``None`` or missing, the target is ``-1``.
        """
        self.partner = partner
        self.ticket = ticket
        self.target_column = target_column

    def __call__(self, key, recs, context):
        # Computed once per key so every row of this key shares one date.
        # datetime.now() is the worker's local time, not UTC.
        upper_bound_date = datetime.now().strftime(FMT_DATE)
        id_rows = []
        for rec in recs:
            if context.table_index == 0:
                id_rows.append({
                    'external_id': rec['external_id'],
                    'id_type': rec['id_type'],
                    'id_value': rec['id_value'],
                    'ticket': self.ticket,
                    'partner': self.partner
                })
            elif id_rows:
                # Per-rec values are invariant across id_rows: evaluate them once.
                target = int(rec.get(self.target_column, -1))
                retro_date = rec['retro_date']
                for id_row in id_rows:
                    yield dict(
                        target=target,
                        retro_date=retro_date,
                        upper_bound_date=upper_bound_date,
                        **id_row
                    )


def sub_months(sourcedate, months):
    """Return ``sourcedate`` moved back by ``months`` calendar months.

    The day is clamped to the length of the resulting month, e.g.
    2020-03-31 minus one month is 2020-02-29. A negative ``months`` moves
    forward in time.

    :param sourcedate: a ``date`` (only ``year``, ``month`` and ``day`` are read).
    :param months: integer number of months to subtract.
    :returns: a new ``datetime.date``.
    """
    # divmod floors toward -infinity, which handles crossing into earlier years
    # with exact integer arithmetic. The previous ``int(year + month / 12)``
    # depended on ``/`` semantics, which differ between Python 2 and 3.
    year_delta, month_index = divmod(sourcedate.month - 1 - months, 12)
    year = sourcedate.year + year_delta
    month = month_index + 1
    day = min(sourcedate.day, calendar.monthrange(year, month)[1])
    return date(year, month, day)


def add_months(sourcedate, months):
    """Return ``sourcedate`` moved forward by ``months`` calendar months.

    The day is clamped to the length of the resulting month, e.g.
    2020-01-31 plus one month is 2020-02-29. A negative ``months`` moves
    backward in time.

    :param sourcedate: a ``date`` (only ``year``, ``month`` and ``day`` are read).
    :param months: integer number of months to add.
    :returns: a new ``datetime.date``.
    """
    # Exact integer arithmetic via divmod; it floors correctly for negative
    # offsets and does not depend on Python 2 vs 3 ``/`` semantics.
    year_delta, month_index = divmod(sourcedate.month - 1 + months, 12)
    year = sourcedate.year + year_delta
    month = month_index + 1
    day = min(sourcedate.day, calendar.monthrange(year, month)[1])
    return date(year, month, day)


def merge_two_dicts(a, b, path=None):
    """Recursively merge ``b`` into ``a`` in place and return ``a``.

    Nested dicts present in both are merged key by key. Otherwise the value
    from ``b`` overwrites the one in ``a``, unless the two compare equal, in
    which case ``a`` keeps its own object. Values copied from ``b`` are not
    copied deeply, so ``a`` may share nested objects with ``b``.

    :param path: list of key names leading to ``a``; only threaded through
        the recursion.
    """
    trail = [] if path is None else path
    for key in b:
        incoming = b[key]
        if key not in a:
            a[key] = incoming
            continue
        current = a[key]
        if isinstance(current, dict) and isinstance(incoming, dict):
            merge_two_dicts(current, incoming, trail + [str(key)])
        elif not current == incoming:
            a[key] = incoming
    return a


def date2month_mapper(rec):
    """Rewrite ``rec['retro_date']`` from ``FMT_DATE`` to ``FMT_DATE_YM`` in place and yield ``rec``.

    Judging by the names, this truncates a day-level date to its month;
    confirm against the pattern definitions.
    """
    rec['retro_date'] = datetime.strptime(
        rec['retro_date'], FMT_DATE
    ).strftime(FMT_DATE_YM)
    yield rec


def count_in_month_reducer(key, recs):
    """Yield a single ``{'month': key['retro_date'], 'count': <number of recs>}`` row."""
    month = key['retro_date']
    total = 0
    for _ in recs:
        total += 1
    yield {'month': month, 'count': total}


def check_target_prefix(s):
    """Return ``True`` when the string ``s`` begins with ``'target'``."""
    target_prefix = 'target'
    return s.startswith(target_prefix)


# YQL script that attaches `cid` values to input identifiers.
# It is a %-style template, meant to be rendered with a mapping for these keys:
#   phone2cid, email2cid, yuid2cid -- lookup tables providing `cid` for an id
#                                     (yuid2cid exposes the id as column `yuid`),
#   input                          -- table read for external_id, timestamp,
#                                     id_type and id_value,
#   all_cid                        -- output table, overwritten (WITH TRUNCATE).
# The lookups are tagged with the id_type labels "phone_md5", "email_md5" and
# "yuid", unioned into $crypta, and INNER JOINed to input on (id_type, id_value).
# Input rows with no match are dropped, and an id with several cids yields
# several rows. Output is sorted by cid.
# NOTE(review): ForceInferSchema presumably exists because some source tables
# lack a strict schema -- confirm. The email2cid and yuid2cid branches are both
# aliased `e`; this is harmless inside UNION ALL but confusing.
join_cids_query = dedent("""
    pragma yt.ForceInferSchema = '1';

    $crypta = (
        SELECT
            "phone_md5" as id_type,
            id_value,
            cid
        FROM `%(phone2cid)s` as p
        UNION ALL
        SELECT
            "email_md5" as id_type,
            id_value,
            cid
        FROM `%(email2cid)s` as e
        UNION ALL
        SELECT
            "yuid" as id_type,
            yuid as id_value,
            cid
        FROM `%(yuid2cid)s` as e
    );

    INSERT INTO `%(all_cid)s` WITH TRUNCATE
    SELECT
        input.external_id as external_id,
        input.`timestamp` as `timestamp`,
        input.id_type as id_type,
        input.id_value as id_value,
        crypta.cid as cid
    FROM `%(input)s` as input
    INNER JOIN $crypta as crypta
    USING(id_type, id_value)
    ORDER BY cid
""")


# YQL script that expands each cid-matched row to the yandexuids sharing its cid.
# It is a %-style template, meant to be rendered with a mapping for these keys:
#   all_cid    -- table read for external_id, timestamp, id_type, id_value and
#                 cid (presumably the output of join_cids_query, which writes
#                 to the same placeholder name),
#   cid_to_all -- table read for cid, id_type and id_value,
#   all_yuid   -- output table, overwritten (WITH TRUNCATE).
# Only cid_to_all rows with id_type = 'yandexuid' survive. Their id_value
# becomes the `yuid` column, and output is sorted by external_id.
join_yuid_query = dedent("""
    INSERT INTO `%(all_yuid)s` WITH TRUNCATE
    SELECT
        all_cid.external_id as external_id,
        all_cid.`timestamp` as `timestamp`,
        all_cid.id_type as id_type,
        all_cid.id_value as id_value,
        all_cid.cid as cid,
        cid_to_all.id_value as yuid
    FROM `%(all_cid)s` as all_cid
    INNER JOIN `%(cid_to_all)s` as cid_to_all
    USING(cid)
    WHERE cid_to_all.id_type = 'yandexuid'
    ORDER BY external_id
""")
