#!/usr/bin/env python
# coding: utf-8

import sys
# Python 2 idiom: switch the interpreter-wide default string encoding to UTF-8.
# The module is reloaded first, presumably because sys.setdefaultencoding is
# not reachable on the freshly imported module — TODO confirm.
reload(sys)
sys.setdefaultencoding("utf-8")

import os
from datetime import datetime
import yt.wrapper as yt
import itertools
from scipy import spatial
from collections import defaultdict

# Minimum cosine similarity (1 - cosine distance) between two query embeddings
# for deduplicate() to put the two queries into the same duplicate group.
THR = 0.99

def query_key(item):
    """Build the hashable identity tuple of a query row.

    The tuple holds, in this order, the row's ``query_text``,
    ``query_region_id``, ``query_country``, ``query_device`` and
    ``query_uid`` values.  A missing field raises ``KeyError``.
    """
    key_fields = (
        'query_text',
        'query_region_id',
        'query_country',
        'query_device',
        'query_uid',
    )
    return tuple(item[field] for field in key_fields)


def get_query_group_total_freq(query, groups):
    """Return ``query[1]`` plus ``member[1]`` of every member of its group.

    ``groups`` maps a group head to the list of its members; a query that is
    not a head contributes only its own ``query[1]``.  The lookup uses
    ``groups.get`` so that a ``defaultdict`` is not populated as a side effect.

    NOTE(review): index 1 is used as the query frequency here, but the tuples
    built by ``query_key`` carry ``query_region_id`` at index 1 — confirm
    which field is intended.
    """
    own_value = query[1]
    members_value = sum(member[1] for member in groups.get(query, ()))
    return own_value + members_value


def _log_step(message):
    """Print ``message`` prefixed with the current local timestamp."""
    # One pre-formatted argument produces the same text whether ``print`` is
    # parsed as a statement (Python 2) or called as a function (Python 3).
    print('{} {}'.format(datetime.now(), message))


def _group_rank(key):
    """Rank a query key: larger ``key[1]`` first, then shorter query text."""
    return (key[1], -len(key[0]))


def deduplicate(queries_list):
    """Collapse near-duplicate queries and return one row per group.

    Two queries are duplicates when the cosine similarity of their
    ``query_embed`` vectors is at least ``THR``.  Grouping is greedy: a query
    joins the group of the first already-grouped query it matches, and two
    existing groups are never merged.

    For every group the best-ranked key (see ``_group_rank``) becomes the
    head.  The result contains the rows of the group heads and of the queries
    that matched nothing, ordered by decreasing total group value
    (``get_query_group_total_freq``) and then by increasing text length.

    NOTE(review): ranking treats ``key[1]`` as the query frequency, while
    ``query_key`` stores ``query_region_id`` at index 1 — confirm which field
    is intended.

    :param queries_list: iterable of row dicts providing the fields read by
        ``query_key`` plus ``query_embed``.
    :return: list of the surviving row dicts.
    """
    embeds = {}
    queries_data = {}
    query_keys = []
    for item in queries_list:
        key = query_key(item)
        # Keep each key once: a repeated key would otherwise be paired with
        # itself, become a member of its own group and be emitted twice.
        # As before, the last row seen for a key supplies its data.
        if key not in queries_data:
            query_keys.append(key)
        embeds[key] = item['query_embed']
        queries_data[key] = item

    total = len(query_keys)
    _log_step('start, {} queries, {} combs'.format(total, total * (total - 1) // 2))

    # Step 1: collect groups of similar queries.
    groups = defaultdict(list)  # head key -> list of member keys
    used = {}                   # any grouped key (head or member) -> head key
    for q1, q2 in itertools.combinations(query_keys, r=2):
        q1_used = q1 in used
        q2_used = q2 in used
        if q1_used and q2_used:
            continue
        sim = 1 - spatial.distance.cosine(embeds[q1], embeds[q2])
        # Written as "not >=" so that a NaN similarity is rejected as well.
        if not sim >= THR:
            continue
        if q1_used:
            head, member = used[q1], q2
        elif q2_used:
            head, member = used[q2], q1
        else:
            head, member = q1, q2
            used[q1] = q1
        # ``member`` is not in ``used`` yet, hence in no group list: no
        # membership scan is needed before appending.
        groups[head].append(member)
        used[member] = head
    _log_step('got groups')

    # Step 2: promote the best-ranked query to group head and sort members.
    # Iterate over a snapshot because heads are replaced inside the loop.
    for main_query, members in list(groups.items()):
        # max() returns the first maximal item, so the current head wins ties.
        max_query = max([main_query] + members, key=_group_rank)
        # Compare whole keys: two distinct keys may share the same text.
        if max_query != main_query:
            members.remove(max_query)
            members.append(main_query)
            del groups[main_query]
        # Members are ordered by decreasing key[1], then increasing length.
        groups[max_query] = sorted(members, key=_group_rank, reverse=True)
    _log_step('resorted')

    # Step 3: order by decreasing total group value, then increasing length,
    # keeping group heads and ungrouped queries only.
    ordered = sorted(
        query_keys,
        key=lambda key: (get_query_group_total_freq(key, groups), -len(key[0])),
        reverse=True
    )
    out = [queries_data[key] for key in ordered if key in groups or key not in used]
    _log_step('done')

    return out


def dedup_by_dssm(job_root, params):
    """Deduplicate one embeddings table and write the result next to it.

    Reads the ``..._query_embeds_raw`` table for the given combination, runs
    ``deduplicate`` on all of its rows in memory and writes the surviving rows
    to the ``..._query_embeds_cleared_dups_1`` table.

    :param job_root: root path under which both tables live.
    :param params: ``(country, platform, basket_type)`` triple.
    :return: path of the written output table.
    """
    country, platform, basket_type = params
    path_fields = dict(
        job_root=job_root,
        country=country,
        platform=platform,
        basket_type=basket_type
    )
    in_table = '{job_root}/{country}/{platform}/06_{basket_type}_merged_parts_query_embeds_raw'.format(**path_fields)
    out_table = '{job_root}/{country}/{platform}/06_{basket_type}_merged_parts_query_embeds_cleared_dups_1'.format(**path_fields)

    # The whole table is materialised because deduplicate() needs every row.
    rows = list(yt.read_table(in_table))

    # Each message is formatted into a single print argument, so the emitted
    # text matches the two-item print statement form.
    print('{} {}'.format(
        datetime.now(),
        'start ({}, {}, {}), {} queries'.format(country, platform, basket_type, len(rows))
    ))
    deduplicated_rows = deduplicate(rows)
    print('{} {}'.format(datetime.now(), 'embeddings done'))

    yt.write_table(out_table, deduplicated_rows)

    return out_table


def main(*args):
    """Run ``dedup_by_dssm`` for every configured combination.

    Exactly six positional arguments are required; none of them is used by
    the body.

    :return: list of output table paths, one per processed combination.
    """
    # Unpacking only enforces the six-argument call shape.
    queries_list, in2, in3, token, embed_key, html_file = args

    job_root = '//home/images/dev/nerevar/baskets_img/2018Q1_v2'

    countries = ['BY']        # not enabled: 'RU', 'UA', 'KZ', 'UZ', 'exUSSR'
    platforms = ['desktop']   # not enabled: 'touch'
    basket_types = ['kpi']    # not enabled: 'validate'

    return [
        dedup_by_dssm(job_root, combination)
        for combination in itertools.product(countries, platforms, basket_types)
    ]
