#!/usr/bin/env python
# coding: utf-8

"""
Drop the most distant outliers from a list of queries with embedding vectors.
"""

import sys
reload(sys)
sys.setdefaultencoding("utf-8")

import os
import json
import random
import itertools
import numpy as np
from scipy import spatial, stats
from datetime import date, datetime
from collections import defaultdict, Counter

# Upper and lower bounds (in percent) for the share of queries that is cut
# off as outliers; the actual share is derived from the spread of the data.
MAX_PERCENT = 30
MIN_PERCENT = 10

# Cosine distance at or above which two queries are treated as not belonging
# to the same topic. Only such "far" pairs contribute to a query's average.
FAR_DISTANCE_THRESHOLD = 0.147


def main(*args):
    """Filter out queries whose embeddings are far from the rest.

    For every pair of queries the cosine distance between their embeddings
    is computed. Pairs at or above ``FAR_DISTANCE_THRESHOLD`` are accumulated
    per query, giving each query an average "far" distance. The cut-off is
    the ``100 - percent`` percentile of those averages, where ``percent`` is
    ``std * 1000`` clamped to ``[MIN_PERCENT, MAX_PERCENT]``. Queries whose
    average is at or above the cut-off are reported as outliers.

    Args (positional, exactly six):
        queries_list: iterable of dicts with keys ``'query'``, ``'cnt'`` and
            ``'embed'`` (``'cnt'`` must be numeric, it is formatted as a
            float in the report).
        in2, in3, token, any_param: accepted for interface compatibility,
            not used by this function.
        html_file: object with a ``write`` method; receives a ``<pre>`` tag,
            the statistics line and the list of outliers.

    Returns:
        List of ``{'query', 'cnt', 'embed'}`` dicts for the kept queries,
        ordered by average far distance, largest first.
    """
    queries_list, in2, in3, token, any_param, html_file = args
    embeds = {(x['query'], x['cnt']): x['embed'] for x in queries_list}
    # Dicts are not hashable, so each query is represented by a tuple key.
    queries_list = [(x['query'], x['cnt']) for x in queries_list]

    total = len(queries_list)
    # combinations(r=2) yields n * (n - 1) / 2 pairs.
    print('{} start, {} queries, {} combs'.format(
        datetime.now(), total, total * (total - 1) // 2))

    distances = defaultdict(float)
    distances_cnt = defaultdict(int)

    for q1, q2 in itertools.combinations(queries_list, r=2):
        dist = spatial.distance.cosine(embeds[q1], embeds[q2])
        if dist >= FAR_DISTANCE_THRESHOLD:
            distances[q1] += dist
            distances[q2] += dist
            distances_cnt[q1] += 1
            distances_cnt[q2] += 1

    print('{} combs done'.format(datetime.now()))

    # Only queries with at least one far pair have an entry here, so the
    # division below can never be by zero.
    avg_by_query = {
        q: dist_sum / float(distances_cnt[q])
        for q, dist_sum in distances.items()
    }
    avg_dists = list(avg_by_query.values())

    # A query with no far pairs is close to everything: give it 0.0, which
    # is always below the cut-off (every real average is >= the threshold).
    data = [(q[0], q[1], avg_by_query.get(q, 0.0)) for q in queries_list]

    if avg_dists:
        std = np.std(avg_dists)
        mean = np.mean(avg_dists)
        # The wider the spread of averages, the larger the share cut off.
        percent = int(round(std * 1000))
        percent = max(min(percent, MAX_PERCENT), MIN_PERCENT)
        thr = np.percentile(avg_dists, 100 - percent)
    else:
        # No far pairs at all (or fewer than two queries): nothing to cut.
        std = mean = 0.0
        percent = 0
        thr = float('inf')

    result = []
    outlier_lines = []
    for query, cnt, avg_dist in sorted(data, key=lambda x: x[2], reverse=True):
        if avg_dist < thr:
            result.append({
                'query': query,
                'cnt': cnt,
                'embed': embeds[(query, cnt)]
            })
        else:
            outlier_lines.append(
                'avg_dist: {:0.6f}, cnt: {:3.0f}, query: {}\n'.format(
                    avg_dist, cnt, query))
    outliers_text = ''.join(outlier_lines)

    stats_str = 'mean_avg_dist {:0.4f}, std_avg_dist {:0.4f}, cut {}%, thr: {}, take items: {}/{}\n'.format(
        mean, std, percent, thr, len(result), len(queries_list)
    )
    print(stats_str)
    if avg_dists:
        print(stats.describe(avg_dists))
    # The <pre> tag is intentionally left open, as in the original output.
    html_file.write('<pre>')
    html_file.write(stats_str)
    html_file.write('outliers:\n' + outliers_text)

    print('{} done'.format(datetime.now()))

    return result
