#!/usr/bin/env python
# coding: utf-8

import sys
# Python 2 only: site.py removes sys.setdefaultencoding at interpreter
# startup, so the module is reloaded to get the function back. Setting the
# default to UTF-8 makes implicit str <-> unicode conversions in this script
# use UTF-8 instead of ASCII. Must run before any code that relies on it.
reload(sys)
sys.setdefaultencoding("utf-8")

import os
from datetime import datetime
import yt.wrapper as yt

# https://sandbox.yandex-team.ru/resource/451408451
# https://a.yandex-team.ru/arc/history/trunk/arcadia/search/wizard/data/wizard/LingBoost/ya.make


def get_queries_embeddings(queries):
    """Yield one embedding (list of floats) per line of the applier output.

    Writes ``q['query_text']`` of every item in ``queries`` to a temporary
    file in the current directory (one per line, followed by a tab and an
    empty second column), pipes that file through the ``nn_applier`` binary
    with the ``lingboost.dssm.model`` model, and parses the space-separated
    numbers the binary writes to ``embeddings.txt``.

    This is a generator: nothing is written or executed until the first
    item is requested.

    Args:
        queries: iterable of mappings that have a ``'query_text'`` key.

    Yields:
        list of float: the numbers parsed from each output line, in file
        order.

    Raises:
        RuntimeError: if the applier command exits with a non-zero status.
        ValueError: if an output token cannot be parsed as a float.
    """
    tmp_queries_file = 'queries_for_dssm.txt'
    result_file = 'embeddings.txt'
    dssm_model = 'lingboost.dssm.model'

    with open(tmp_queries_file, 'w') as f:
        for q in queries:
            f.write('{}\t\n'.format(q['query_text']))
    print('{} queries written'.format(datetime.now()))

    cmd = (
        'cat {} | ./kernel/dssm_applier/nn_applier/nn_applier apply -m {} '
        '--header query,expansion -o query_embedding > {}'
    ).format(tmp_queries_file, dssm_model, result_file)
    print(cmd)
    try:
        status = os.system(cmd)
    finally:
        # The input file is only needed while the command runs.
        os.remove(tmp_queries_file)
    if status != 0:
        # Without this check a failed run would silently yield whatever
        # happens to be in the result file (possibly nothing, or stale data).
        raise RuntimeError(
            'nn_applier command failed with status {}: {}'.format(status, cmd)
        )

    with open(result_file) as embeddings:
        for line in embeddings:
            # float() tolerates the trailing newline on the last token.
            yield [float(x) for x in line.split(' ') if x]


def append_dssm_embed(job_root, params):
    """Add a ``query_embed`` column to one basket table and write the result.

    Reads every row of
    ``{job_root}/{country}/{platform}/05_{basket_type}_merged_parts_cleared_dups_1``,
    computes an embedding for each row via ``get_queries_embeddings``, stores
    it under the ``'query_embed'`` key of the row, and writes the rows to
    ``{job_root}/{country}/{platform}/06_{basket_type}_merged_parts_query_embeds_raw``.

    Args:
        job_root: YT path prefix under which the tables live.
        params: 3-item sequence ``(country, platform, basket_type)``.

    Returns:
        str: path of the output table.

    Raises:
        ValueError: if the number of embeddings differs from the number of
            rows read, since pairing them up would drop or misalign rows.
    """
    country, platform, basket_type = params
    path_fields = dict(
        job_root=job_root,
        country=country,
        platform=platform,
        basket_type=basket_type,
    )
    in_table = (
        '{job_root}/{country}/{platform}/'
        '05_{basket_type}_merged_parts_cleared_dups_1'
    ).format(**path_fields)
    out_table = (
        '{job_root}/{country}/{platform}/'
        '06_{basket_type}_merged_parts_query_embeds_raw'
    ).format(**path_fields)
    data = list(yt.read_table(in_table))

    print('{} start ({}, {}, {}), {} queries'.format(
        datetime.now(), country, platform, basket_type, len(data)))
    # get_queries_embeddings is a generator; materialize it here so the work
    # really is finished before it is reported as done and the count can be
    # checked.
    embs = list(get_queries_embeddings(data))
    if len(embs) != len(data):
        raise ValueError(
            'got {} embeddings for {} queries from {}'.format(
                len(embs), len(data), in_table)
        )
    print('{} embeddings done'.format(datetime.now()))

    def data_with_embed_generator(data, embs):
        # Rows are updated in place and streamed to the writer one by one.
        for q, emb in zip(data, embs):
            q['query_embed'] = emb
            yield q

    yt.write_table(out_table, data_with_embed_generator(data, embs))

    return out_table


def main(*args):
    """Run ``append_dssm_embed`` for every configured basket combination.

    Args:
        *args: exactly six positional values
            ``(queries_list, in2, in3, token, embed_key, html_file)``.
            None of them is used by this function; they are only unpacked,
            which enforces the argument count.

    Returns:
        list of str: output table paths, one per
        (country, platform, basket_type) combination.

    Raises:
        ValueError: if the number of arguments is not six.
    """
    # itertools is not imported at module level in this file, so import it
    # here; otherwise the product() call below raises NameError.
    import itertools

    queries_list, in2, in3, token, embed_key, html_file = args

    job_root = '//home/images/dev/nerevar/baskets_img/2018Q1_v2'

    combinations = itertools.product(
        ['BY'],  # 'RU', 'UA', 'KZ', 'UZ', 'exUSSR'
        ['desktop'],  # 'touch'
        ['kpi'],  # 'validate'
    )
    return [append_dssm_embed(job_root, tup) for tup in combinations]
