import random
from textwrap import dedent

# Split mark given to cids sampled into the validation set (see MarkCidsMapper).
VAL_MARK = 'val'
# Mark given to an external_id whose cids received more than one distinct
# mark (see mark_eids_reducer).
CONFUSED_MARK = 'confused'


# YQL query template (Python %-style named placeholders): copy `in_path`
# into `out_path`, overwriting it (WITH TRUNCATE), and drop rows that are
# exact duplicates over the eight listed columns. Any other input columns
# are not carried over, because only these columns are selected.
# Placeholders: out_path, in_path.
rm_complete_dups_q = dedent("""
    INSERT INTO `%(out_path)s` WITH TRUNCATE
    SELECT DISTINCT
        `id_type`,
        `id_value`,
        `external_id`,
        `ticket`,
        `partner`,
        `retro_date`,
        `upper_bound_date`,
        `target`
    FROM `%(in_path)s`
""")


# YQL query template that cleans the raw sample:
#   * rebuilds `external_id` as "<external_id>_<retro_date>_<ticket>", with
#     NULL parts replaced by "" (`??` coalesces). The same original id with a
#     different retro_date or ticket therefore becomes a distinct key;
#   * keeps only rows with target >= 0, `retro_date` greater than the quoted
#     min_retro_date literal, and a partner not in no_go_partners (a NULL
#     partner is compared as "");
#   * orders the output by external_id.
# Placeholders:
#   * out_path, in_path;
#   * min_retro_date, inserted inside single quotes;
#   * no_go_partners, inserted verbatim, so it must already be a valid YQL
#     list literal such as ('a', 'b') (TODO confirm against caller).
# NOTE(review): values are interpolated with %-formatting without escaping;
# only pass trusted values.
# NOTE(review): confirm whether ORDER BY external_id resolves to the rebuilt
# alias or to the source column.
clean_cse_q = dedent("""
    INSERT INTO `%(out_path)s` WITH TRUNCATE
    SELECT
        `id_type`,
        `id_value`,
        (`external_id` ?? "") || '_' ||
        (`retro_date` ?? "") || '_' ||
        (`ticket` ?? "") as `external_id`,
        `ticket`,
        `partner`,
        `retro_date`,
        `upper_bound_date`,
        `target`
    FROM `%(in_path)s`
    WHERE `target` >= 0 and
          `retro_date` > '%(min_retro_date)s' and
          (`partner` ?? "") NOT IN %(no_go_partners)s
    ORDER BY external_id;
""")


# YQL query template: write the distinct external_ids of `in_path` rows whose
# target equals %(target)s. The target is inserted unquoted, so it is
# presumably numeric (confirm against caller).
# Placeholders: out_path, in_path, target.
get_eids_q = dedent("""
    INSERT INTO `%(out_path)s` WITH TRUNCATE
    SELECT DISTINCT `external_id`
    FROM `%(in_path)s`
    WHERE `target` = %(target)s
""")


def rows_sample(yt_client, table_path, sample_size, seed=42):
    """Read a whole table and return a seeded random sample of its rows.

    :param yt_client: client object exposing ``read_table``. It is presumably
        a YT client; the call passes ``enable_read_parallel=True``.
    :param table_path: path of the table to read.
    :param sample_size: number of rows to return. ``random.sample`` raises
        ``ValueError`` if it exceeds the number of rows in the table.
    :param seed: seed of the sampling RNG. The same seed and the same rows,
        read in the same order, give the same sample.
    :return: list of ``sample_size`` distinct rows (sampled without
        replacement).
    """
    # A private RNG yields exactly the sample that seeding the module-level
    # RNG would, without resetting the global random state other code uses.
    rng = random.Random(seed)
    rows = list(yt_client.read_table(table_path, enable_read_parallel=True))
    return rng.sample(rows, sample_size)


# YQL query template: keep the rows of `in_path` whose external_id is present
# in `eids_table` (inner join), sorted by external_id. Only the `in_path`
# columns (cse.*) are written.
# NOTE(review): a repeated external_id in `eids_table` would duplicate the
# matching rows. Presumably the DISTINCT output of get_eids_q is what gets
# passed here; verify.
# Placeholders: out_path, in_path, eids_table.
filter_eids_q = dedent("""
    INSERT INTO `%(out_path)s` WITH TRUNCATE
    SELECT cse.*
    FROM `%(in_path)s` as cse
    INNER JOIN `%(eids_table)s`
    USING(`external_id`)
    ORDER BY `external_id`
""")


# YQL query template: map every external_id to the cids of its ids by joining
# on (id_type, id_value) with the `id_value_to_cid` table. It writes distinct
# (external_id, cid) pairs. Ids without a match are dropped (inner join).
# Placeholders: out_path, cse_interesting_eids, id_value_to_cid.
join_cids_q = dedent("""
    INSERT INTO `%(out_path)s` WITH TRUNCATE
    SELECT DISTINCT
        cse.`external_id` as `external_id`,
        crypta.`cid` as `cid`
    FROM `%(cse_interesting_eids)s` as cse
    INNER JOIN `%(id_value_to_cid)s` as crypta
    USING(`id_type`, `id_value`)
""")


def glue_eids(eid2ids):
    """Group external ids that are transitively linked through shared ids.

    Two external ids land in the same group when their id collections share
    an element, directly or through a chain of other external ids. The
    groups are the connected components of the "shares an id" graph.

    :param eid2ids: mapping ``external_id -> iterable of hashable ids``.
    :return: list of groups, each a list of external ids. Groups appear in
        the order of their first member in ``eid2ids`` iteration order, and
        every external id appears in exactly one group.
    """
    # Adjacency lists of the eid graph. Each eid is linked to the first eid
    # that claimed a shared id, which is enough to keep components connected.
    neighbours = {}
    first_owner = {}  # id -> first eid seen with that id
    for eid, ids in eid2ids.items():  # .items() works on Python 2 and 3
        neighbours[eid] = []
        for idr in ids:
            owner = first_owner.setdefault(idr, eid)
            # Skip self-links (new id, or id repeated within the same eid).
            if owner != eid:
                neighbours[owner].append(eid)
                neighbours[eid].append(owner)

    # Iterative DFS so large components cannot hit the recursion limit.
    groups = []
    visited = set()
    for start in eid2ids:
        if start in visited:
            continue
        group = []
        stack = [start]
        while stack:
            eid = stack.pop()
            if eid in visited:
                continue
            visited.add(eid)
            group.append(eid)
            stack.extend(neighbours[eid])
        groups.append(group)
    return groups


def list_suffixes(cconfig):
    """Return the split suffixes: every fold number plus the validation mark.

    :param cconfig: config object exposing an integer ``n_folds`` attribute.
    :return: ``[1, 2, ..., n_folds, VAL_MARK]``.
    """
    # ``range`` is wrapped in ``list`` so the concatenation also works on
    # Python 3, where ``range + list`` raises TypeError. VAL_MARK is the same
    # mark that MarkCidsMapper assigns to validation cids.
    return list(range(1, cconfig.n_folds + 1)) + [VAL_MARK]


class MarkCidsMapper:
    """Tag each input row with the split (fold or validation) of its cid.

    Each call yields exactly one row with ``external_id``, ``cid`` (copied
    unchanged) and ``mark``:

    * ``mark`` is ``VAL_MARK`` when validation sampling is enabled and
      ``int(cid)`` is a multiple of ``val_sample_rate``;
    * otherwise it is the fold number ``int(cid) % nfolds + 1`` as a string.

    NOTE(review): despite its name, ``val_sample_rate`` acts as a modulus
    (every N-th cid goes to validation), not as a fraction. A value such as
    0.1 rounds to 0 and disables validation.
    NOTE(review): when ``val_sample_rate`` shares a factor with ``nfolds``,
    validation cids are taken unevenly from the folds. With equal values,
    fold '1' receives no cids. Confirm this is intended.
    """

    def __init__(self, nfolds, val_sample_rate=0):
        self.nfolds = nfolds
        # Kept as an int because it is used as a modulus; 0 disables
        # validation sampling.
        self.val_sample_rate = int(round(val_sample_rate))

    def __call__(self, rec):
        numeric_cid = int(rec['cid'])
        rate = self.val_sample_rate
        # Short-circuit keeps a zero rate from being used as a divisor.
        in_validation = bool(rate) and numeric_cid % rate == 0
        mark = VAL_MARK if in_validation else str(numeric_cid % self.nfolds + 1)
        yield dict(external_id=rec['external_id'], cid=rec['cid'], mark=mark)


def mark_eids_reducer(key, recs):
    """Collapse the cid marks of one external_id into a single mark.

    :param key: reduce key; only ``key['external_id']`` is read.
    :param recs: rows for that external_id, each carrying a ``mark``.
    :yield: one row ``{'external_id': ..., 'mark': ...}``. The mark is the
        one shared by all rows, or ``CONFUSED_MARK`` if the rows disagree.
    """
    distinct_marks = {rec['mark'] for rec in recs}
    # Any disagreement between the cids of one eid makes the eid unusable.
    mark = CONFUSED_MARK if len(distinct_marks) > 1 else distinct_marks.pop()
    yield {'external_id': key['external_id'], 'mark': mark}


# YQL query template: keep the rows of `cse` whose external_id carries the
# given split mark in `marked_eids`. Presumably that table is the output of
# mark_eids_reducer, with a fold number or VAL_MARK as the mark; verify.
# Only the `cse` columns are written.
# Placeholders: out_path, cse, marked_eids, mark (inserted inside quotes).
filter_by_eids_q = dedent("""
    INSERT INTO `%(out_path)s` WITH TRUNCATE
    SELECT cse.*
    FROM `%(cse)s` as cse
    INNER JOIN `%(marked_eids)s` as marked_eids
    USING(`external_id`)
    WHERE marked_eids.`mark` = '%(mark)s'
""")


def compact_reducer(key, recs):
    """Fold all id rows of one reduce key into a single compact row.

    Each input row carries an ``id_type`` (``'email_md5'`` or
    ``'phone_md5'``) and an ``id_value``. Values of each type are joined
    with commas, in input order.

    :param key: reduce key. Its ``external_id``, ``retro_date``, ``partner``
        and ``target`` are copied to the output.
    :param recs: rows sharing that key.
    :yield: one row whose ``phone_id_value`` / ``email_id_value`` hold the
        comma-joined values, or ``None`` when there are none of that type.
    :raises ValueError: on any other ``id_type``.
    """
    emails, phones = [], []
    for rec in recs:
        id_type = rec['id_type']
        if id_type == 'phone_md5':
            bucket = phones
        elif id_type == 'email_md5':
            bucket = emails
        else:
            raise ValueError('Bad id_type {}'.format(id_type))
        bucket.append(rec['id_value'])

    # Sanity check: the group must contribute at least one id.
    assert emails or phones

    row = {'external_id': key['external_id']}
    row['phone_id_value'] = ','.join(phones) or None
    row['email_id_value'] = ','.join(emails) or None
    for field in ('retro_date', 'partner', 'target'):
        row[field] = key[field]
    yield row


def make_file_name(suffix):
    """Return the TSV file name for *suffix*, e.g. ``normalized_3.tsv``."""
    return 'normalized_' + str(suffix) + '.tsv'
