#!/usr/bin/env python
# coding: utf-8

import sys
# Python 2-only workaround: site.py removes sys.setdefaultencoding at startup,
# so the module is reloaded to get it back. Setting the default to UTF-8 makes
# implicit str <-> unicode conversions (e.g. str.format with non-ASCII query
# text) use UTF-8 instead of ASCII. This affects the whole process.
reload(sys)
sys.setdefaultencoding("utf-8")

import os
import json
import itertools
import subprocess
from scipy import spatial
from datetime import date, datetime
from collections import defaultdict


# Model file passed to the external `dssm3 apply -m ...` command.
DSSM_MODEL = 'one_se_hard_am_ss_compressed.dssm'
# Minimum cosine similarity between two query embeddings for the queries
# to be placed in the same group.
THR = 0.98
# Frequency cut-off applied in main() when choosing which queries to return.
QUERY_MIN_FREQUENCY = 5


def get_queries_embeddings(queries):
    TMP_QUERIES_FILE = 'queries_for_dssm.txt'
    RESULT_FILE = 'embeddings.txt'
    with open(TMP_QUERIES_FILE, 'w') as f:
        for q in queries:
            f.write('{}\n'.format(q[0]))
    print datetime.now(), 'queries written'

    cmd = 'cat {} | ./ml/dssm/dssm/dssm3 apply -m {} --header "query" -o query_embedding > {}'.format(TMP_QUERIES_FILE, DSSM_MODEL, RESULT_FILE)
    print cmd
    os.system(cmd)
    os.system('rm {}'.format(TMP_QUERIES_FILE))

    result = {}
    for q, emb in zip(queries, open(RESULT_FILE)):
        result[q] = [float(x) for x in emb.split(' ') if len(x) > 0]
    return result


def get_query_group_total_freq(query, groups):
    """Return the frequency of ``query`` plus that of all its group members.

    Args:
        query: item whose second element (``query[1]``) is its frequency.
        groups: mapping from a group's head query to the list of its member
            queries; a query that is not a key has no members.

    Returns:
        ``query[1]`` plus the sum of ``member[1]`` over ``groups[query]``.
    """
    # .get() is used so that a defaultdict is not grown by the lookup.
    members_total = 0
    for member in groups.get(query, []):
        members_total += member[1]
    return query[1] + members_total


def main(*args):
    queries_list, in2, in3, token, any_param, html_file = args
    queries_list = [(x[0], x[1]) for x in queries_list] # list -> tuple

    print datetime.now(), 'start, {} queries, {} combs'.format(len(queries_list), len(queries_list) * len(queries_list) / 2)
    embs = get_queries_embeddings(queries_list)
    print datetime.now(), 'embeddings done'

    # шаг 1 — собираем группы похожих запросов
    groups = defaultdict(list)
    used = {}
    for q1, q2 in itertools.combinations(queries_list, r=2):
        if q1 in used and q2 in used:
            continue
        sim = 1 - spatial.distance.cosine(embs[q1], embs[q2])
        if sim >= THR:
            if q1 in used:
                if q2 not in groups[used[q1]]:
                    groups[used[q1]].append(q2)
                used[q2] = used[q1]
            elif q2 in used:
                if q1 not in groups[used[q2]]:
                    groups[used[q2]].append(q1)
                used[q1] = used[q2]
            else:
                groups[q1].append(q2)
                used[q2] = q1
                used[q1] = q1
    print datetime.now(), 'got groups'

    # шаг 2 — выбор самого частотного запроса и сортировка
    for main_query, qlist in groups.items():
        max_query = main_query
        for q in qlist:
            if q[1] > max_query[1] or (q[1] == max_query[1] and len(q[0]) < len(max_query[0])):
                max_query = q
        if max_query[0] != main_query[0]:
            # ставим в вершину группы запрос с максимальной частотностью
            qlist.remove(max_query)
            qlist.append(main_query)
            del groups[main_query]
            # сортировка внутри группы по уменьшению count и увеличению длины запроса
            groups[max_query] = sorted(qlist, key=lambda x: (x[1], -len(x[0])), reverse=True)
        else:
            groups[main_query] = sorted(qlist, key=lambda x: (x[1], -len(x[0])), reverse=True)
    print datetime.now(), 'resorted'

    out = []
    # сортировка по уменьшению суммарной частоты группы и увеличению длины запроса
    for q in sorted(queries_list, key=lambda x: (get_query_group_total_freq(x, groups), -len(x[0])), reverse=True):
        freq = get_query_group_total_freq(q, groups)
        if q in groups:
            if freq > QUERY_MIN_FREQUENCY:
                html_file.write('{} ({})\n'.format(q[0], freq))
                out.append((q[0], freq))
                for item in groups[q]:
                    html_file.write('        {} ({})\n'.format(item[0], item[1]))
        elif q in used:
            continue
        else:
            html_file.write('{} ({})\n'.format(q[0], q[1]))
            if q[1] >= QUERY_MIN_FREQUENCY:
                out.append(q) # запросы с частотой QUERY_MIN_FREQUENCY и выше
    print datetime.now(), 'done'

    return out
