#!/usr/bin/env python
#  -*- coding: utf-8 -*-
from __future__ import division
import sys
reload(sys)
sys.setdefaultencoding("utf-8")

__author__ = 'aalogachev'

# import modules & set up logging
import gensim, logging
from gensim import matutils
from numpy import array
from scipy.cluster.vq import vq, kmeans, kmeans2
from sklearn.neighbors import KDTree
from ru.yandex.simularity.my_sentences import MySentences
from ru.yandex import utils
import os, ConfigParser


def topic_to_string(topic):
    """Render a topic (a sequence of words) as one space-separated unicode string."""
    words = list(topic)
    return u' '.join(words)

def get_vector(model, clear_sent):
    """Return the unit-normalised mean of the model vectors of ``clear_sent``.

    Every word is looked up with ``model[word]``, so all words must be present
    in the model's vocabulary.
    """
    stacked = array([model[token] for token in clear_sent])
    centroid = stacked.mean(axis=0)
    return matutils.unitvec(centroid)

def get_vector_str(vector):
    """Serialise ``vector`` as the ``str()`` of each component, tab-separated.

    An empty vector yields an empty string.
    """
    return '\t'.join(str(num) for num in vector)

def measure_and_write(file_name, out_file_name, topics, model, collect_in_memory, threshold):
    """Score every line of a file against each topic and write the scores out.

    Only the first tab-separated column of each input line is treated as the
    sentence. Its whitespace-separated words that are not in ``model`` are
    dropped, and lines with no remaining words are skipped entirely. For every
    other line, ``out_file_name`` receives the stripped input line followed by
    one ``model.n_similarity(topic, words)`` score per topic, each
    tab-terminated.

    :param file_name: path of the input text file.
    :param out_file_name: path of the output file (overwritten).
    :param topics: list of topics, each a list of words given to n_similarity.
    :param model: word-vector model supporting ``in``, ``model[word]`` and
        ``n_similarity`` (a gensim Word2Vec model in this script).
    :param collect_in_memory: if true, lines whose best topic score is
        ``>= threshold`` also get their sentence vector appended to the output
        row and are collected into the returned lists.
    :param threshold: minimum best-topic score for a line to be collected.
    :return: ``(observ, all_lines)``, the collected sentence vectors and the
        matching stripped input lines. Both are empty unless
        ``collect_in_memory`` is set.
    """
    observ = []
    all_lines = []
    line_counter = 0
    processed_counter = 0

    # The output is opened (and truncated) before the input is read, as before.
    with open(out_file_name, mode='w') as out_full_pred, open(file_name, mode='r') as in_stream:
        for line in in_stream:
            line = line.strip()
            line_counter += 1

            # Only the first column is the sentence; keep in-vocabulary words.
            clear_sent = [word for word in line.split('\t')[0].split() if word in model]

            if clear_sent:
                processed_counter += 1
                sims = [model.n_similarity(topic, clear_sent) for topic in topics]
                out_full_pred.write(line + '\t')
                out_full_pred.write(u''.join(u'{0}\t'.format(sim) for sim in sims))

                # With no topics `sims` is empty and the line is never collected.
                if collect_in_memory and sims and max(sims) >= threshold:
                    vector = get_vector(model, clear_sent)
                    out_full_pred.write(u'{0}'.format(get_vector_str(vector)))
                    observ.append(vector)
                    all_lines.append(line)

                out_full_pred.write('\n')

            if line_counter % 1000 == 0:
                logger.info(u'Readed {0} lines, processed {1} lines'.format(line_counter, processed_counter))

    return (observ, all_lines)

def _mean_topic_weights(in_file, topic_count):
    """Return the per-topic average of columns 1..topic_count over all lines of in_file.

    An empty file yields all-zero weights instead of dividing by zero.
    """
    all_weight = [0] * topic_count
    line_count = 0
    with open(in_file, mode='r') as in_stream:
        for line in in_stream:
            fields = line.strip().split('\t')
            line_count += 1
            if line_count % 1000 == 0:
                logger.info(u'Lines readed = {0}'.format(line_count))
            for i in range(topic_count):
                all_weight[i] += float(fields[i + 1])
    if line_count:
        all_weight = [weight / line_count for weight in all_weight]
    return all_weight


def split_by_topics(in_file, out_file, topics):
    """Assign every line of ``in_file`` to its highest-scoring topic.

    NOTE: assumes each line is "query<\t>topic_mark0<\t>topic_mark1...", with
    one numeric mark per topic after the query column.

    Writes "query<\t>topic_index<\t>topic words" per input line to
    ``out_file``. Ties go to the lowest topic index.
    """
    topic_count = len(topics)
    logger.info(u'Topics weights = {0}'.format(_mean_topic_weights(in_file, topic_count)))

    # The averages above are only logged. Every topic is divided by 1, so the
    # raw marks decide the assignment.
    # NOTE(review): confirm whether normalising by the averages was intended.
    all_weight = [1] * topic_count

    with open(in_file, mode='r') as in_stream, open(out_file, mode='w') as out_stream:
        for line in in_stream:
            fields = line.strip().split('\t')
            # max() keeps the first index among equal maxima.
            topic_index = max(range(topic_count),
                              key=lambda i: float(fields[i + 1]) / all_weight[i])
            out_stream.write(u'{0}\t{1}\t{2}\n'.format(
                fields[0], topic_index, topic_to_string(topics[topic_index])))

# Number of nearest words logged per topic at the end of the run.
MOST_POPULAR_COUNT = 10

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


# Paths come from config_queries.json in the project root. Despite the
# extension, it is parsed with ConfigParser (INI syntax). The section to use
# is named by the `active_section` key of the defaults section.
config = ConfigParser.SafeConfigParser()
with open(os.path.join(utils.get_project_path(), 'config_queries.json')) as config_stream:
    config.readfp(config_stream)
section = config.defaults()['active_section']
logger.debug('Using config section =' + section)
work_dir = config.get(section=section, option='work_dir')
logger.debug("Using work dir = " + work_dir)
train_dir = config.get(section, 'train_dir')
train_file = config.get(section, 'train_file')
model_file = config.get(section, 'model_file')

full_file = config.get(section, 'full_file')
out_full_pred_file = config.get(section, 'out_full_pred_file')

classification_file = config.get(section, 'classification_file')
out_pred_file = config.get(section, 'out_pred_file')
centorids_file = config.get(section, 'centorids_file')
clusters_file = config.get(section, 'clusters_file')

# train_dir = '/Users/aalogachev/Projects/AdWords/queries/'
# train_file = train_dir + 'search_2015-06-22_sample'
# model_file = train_dir + 'queries_full.bin'
#
# full_file = train_file
# out_full_pred_file = work_dir + 'queries_pred.txt'
#
# classification_file = work_dir + 'top_queries.txt'
# #FIXME perform local learn for queries
# #model_local_file = work_dir + 'model_fines.bin'
# out_pred_file = work_dir + 'queries_top_pred.txt'
# centorids_file = work_dir + 'centroids.txt'
# clusters_file = work_dir + 'clusters.txt'

# Run-time switches.
threshold = 0           # minimum best-topic similarity for a line to be collected
do_train = False        # train a new Word2Vec model instead of loading model_file
do_full = True          # score every line of full_file
init_knn = False        # build a KD-tree over collected vectors
do_topics_split = True  # write the best-topic assignment for each scored line

model = None

if not do_train:
    model = gensim.models.Word2Vec.load(model_file)
else:
    # Train on train_file and persist the result for later runs.
    sentences = MySentences(train_file)
    model = gensim.models.Word2Vec(sentences, workers=8)
    model.save(model_file)

# Topics to score queries against. Each topic is a list of seed words that
# measure_and_write passes to model.n_similarity together with a query's words.
#
# Earlier topic sets, kept for reference:
#   fines:
#     [['оплатить', 'штраф', 'гибдд'], ['проверить', 'штраф', 'гибдд']]
#   serials:
#     [['смотреть', 'сериал'], ['найти', 'сериал'], ['сериал']]
#   translation:
#     [['переводчик'], ['как', 'на','английском'], ['перевод', 'на', 'русский']]
#     [["перевод"], ["переводчик"], ["перевод", "онлайн"], ["переводчик", "онлайн"], ["translate"], ["англо-русский"], ["машинный", "перевод"]]
#     [['переводчик','английский', 'русский'], ['переводчик','украинский', 'русский'], ['переводчик','турецкий', 'русский']]
#   recipes:
#     [['рецепт', 'блюд']]
topics = [
    ['салат'],    # salad
    ['первое'],   # first course
    ['суп'],      # soup
    ['десерт'],   # dessert
    ['сладкое'],  # sweets
    ['торт'],     # cake
    ['горячее'],  # hot dish
    ['второе'],   # main course
]
# Example config value: full_file : %(train_dir)ssearch_2015-06-22_sample

tree = None

if do_full:
    # Score the whole corpus against the topics. Vectors are not collected
    # here (collect_in_memory=False), so the returned lists are always empty
    # and are not kept.
    measure_and_write(full_file, out_full_pred_file, topics, model, False, threshold)
    if do_topics_split:
        split_by_topics(out_full_pred_file, out_full_pred_file + '.topics', topics)

# Score the classification set and keep the sentence vectors, plus the lines
# they came from, for the nearest-neighbour search and clustering below.
observ, all_lines = measure_and_write(classification_file, out_pred_file, topics, model, True, threshold)

if init_knn:
    if observ:
        # Index the same vectors that `all_lines` describes, so the returned
        # indices map back to the right queries.
        tree = KDTree(observ, 20)
        # KDTree.query expects a 2-D array of points, and k cannot exceed the
        # number of indexed points.
        dist, inds = tree.query([get_vector(model, topics[0])], k=min(20, len(observ)))
        for ind in inds[0]:
            logger.info('query[{0}]={1}'.format(ind, all_lines[ind]))
    else:
        logger.warning(u'No vectors collected; skipping nearest-neighbour search')

# k-means clustering of the collected classification vectors.
N_CLUSTERS = 100

if observ:
    obs_array = array(observ)
    # Asking for more clusters than observations is meaningless, and recent
    # scipy versions raise an error for it.
    centroids, _ = kmeans(obs_array, min(N_CLUSTERS, len(observ)))
    # Write one centroid per line, components tab-separated. Formatting the
    # whole ndarray would go through numpy's repr, which elides the middle of
    # large arrays with '...'.
    with open(centorids_file, mode='w') as out_file:
        for centroid in centroids:
            out_file.write('{0}\n'.format(get_vector_str(centroid)))

    # k-means marking: assign each observation to its nearest centroid and
    # write "cluster<\t>line".
    cluster_marks, _ = vq(obs=obs_array, code_book=centroids)
    with open(clusters_file, mode='w') as out_file2:
        for mark, query_line in zip(cluster_marks, all_lines):
            out_file2.write('{0}\t{1}\n'.format(mark, query_line))
else:
    logger.warning(u'No vectors collected; skipping k-means clustering')

# Log the model's nearest words for every topic as a sanity check.
for topic in topics:
    # Join the words into readable text. Formatting the list itself would log
    # its repr, which shows escaped UTF-8 byte strings under Python 2.
    logger.info(u'{0} most popular for {1}'.format(MOST_POPULAR_COUNT, topic_to_string(topic)).encode('UTF-8'))
    simset = model.most_similar(positive=topic, topn=MOST_POPULAR_COUNT)
    for word, _similarity in simset:
        logger.info(word)