# -*- coding: utf-8 -*-
import time
import copy
import numpy as np
import yt.wrapper as yt
from yt.wrapper import ypath_join
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.yql.yql_helpers import execute_yql
from datacloud.dev_utils.yt import yt_utils
from datacloud.config.yt import YT_PROXY

# Module-level logger used by the pipeline class and the training helpers below.
logger = get_basic_logger(__name__)

# Offset (seconds) subtracted from the retro date: a log row is used only if it
# is strictly older than ``retro_date - TIMESHIFT`` (see ReduceRetroDateAndSplit).
TIMESHIFT = 0
SECONDS_IN_DAY = 60 * 60 * 24
# Base spec for the YT operations launched by the pipeline steps; each step
# copies it and appends its own name to ``title``.
SPEC = dict(
    tentative_pool_trees=['cloud'],
    pool_trees=['physical'],
    title='[Custom CountVectorizer Dict]',
    max_failed_job_count=100,
)


def run_step(result_table, step_function, yt_client=None):
    """Run ``step_function`` unless ``result_table`` already holds data.

    A step counts as done when its result table exists and has at least one
    row. This is the same rule DoSemanticSearch.check_status applies.

    Args:
        result_table: YT path of the table the step produces.
        step_function: zero-argument callable that performs the step.
        yt_client: client exposing ``exists``/``row_count``; defaults to the
            global ``yt`` module.

    Returns:
        True if the step was executed, False if it was skipped.
    """
    client = yt if yt_client is None else yt_client
    # The previous version checked the literal string 'result_table', compared
    # the function ``row_count`` to 0 instead of calling it, and never called
    # ``step_function``.
    if client.exists(result_table) and client.row_count(result_table) > 0:
        return False
    step_function()
    return True


class DoSemanticSearch(object):
    """Resumable YT pipeline: title words -> lemma dictionary -> word selection -> final models.

    Every step writes one result table under ``garbage_folder``. Before a step
    runs, ``check_status`` marks it done when that table already exists and is
    non-empty, and the step is then skipped. A failed run can therefore be
    restarted and continues from the first unfinished step.

    Args:
        project_folder: YT folder of the project. Expected to contain the
            ``input_yuid`` table and the grepped logs under
            ``datacloud/grep/spy_log`` and ``datacloud/grep/watch_log_tskv``.
        target_names: names of the target columns taken from ``input_yuid``.
        yt_client: YT client used for every operation.
        history_length: length of the history window in days (stored in seconds).
        suffix: optional suffix of the working folder ``CV_garbage[_<suffix>]``.
        min_amount: a lemma is kept only if its total count is greater than this.
        min_ids: a lemma is kept only if its ``ids`` value is greater than this.

    Note: the constructor creates the working folder when it does not exist.
    """

    def __init__(self, project_folder, target_names, yt_client, history_length=180, suffix=None, min_amount=100, min_ids=5):
        self.project_folder = project_folder
        self.target_names = target_names
        self.yt_client = yt_client
        self.history_length = history_length * SECONDS_IN_DAY
        self.ticket_name = project_folder.split('/')[-1]
        if suffix is None:
            self.garbage_folder = ypath_join(project_folder, 'CV_garbage')
        else:
            self.garbage_folder = ypath_join(project_folder, 'CV_garbage' + '_' + suffix)

        if not self.yt_client.exists(self.garbage_folder):
            self.yt_client.mkdir(self.garbage_folder)
        self.input_yuid = ypath_join(project_folder, 'input_yuid')
        self.min_amount = min_amount
        self.min_ids = min_ids

        # important tables (each is the result table of one step below):
        self.lemmered_filtered_indexed_words = self.make_result_table('lemmered_filtered_indexed_words')
        self.words_by_yuids = self.make_result_table('words_by_yuids')
        self.listed_words = self.make_result_table('listed_words')
        self.listed_words_with_targets = self.make_result_table('listed_words_with_targets')
        self.result_indexes = self.make_result_table('result_indexes')
        self.final_listed_words_with_targets = self.make_result_table('final_listed_words_with_targets')
        self.model_results = self.make_result_table('model_results')

        # intermediate tables
        self.indexed_useful_dict = ypath_join(self.garbage_folder, 'indexed_useful_dict')
        # Filtered lemma dictionary written by step2 (yql_lemmatize) and read by
        # step5/step6. It is defined here rather than inside step5, so step6
        # still works when step5 is skipped as already done.
        self.lemmered_dict_filtered = self.make_result_table('lemmered_dict_filtered')

        # doing steps
        self.steps = []
        self.add_step('Looking for the all words', self.words_by_yuids, self.step1)
        self.add_step('Doing lemmantization and filtration', self.lemmered_filtered_indexed_words, self.step2)
        self.add_step('Making lists for CSR matrix', self.listed_words, self.step3)
        self.add_step('Adding targets', self.listed_words_with_targets, self.step4)
        self.add_step('Search of useful words set', self.result_indexes, self.step5)
        self.add_step('Only useful word counts', self.final_listed_words_with_targets, self.step6)
        self.add_step('Make final predictions', self.model_results, self.step7)

        logger.info('Object for ticket {} initialized'.format(self.ticket_name))

    def __call__(self):
        """Run all registered steps in order, skipping the ones already done."""
        for i, step in enumerate(self.steps):
            logger.info('Step#{}: [{}]'.format(i + 1, step['step_name']))
            self.run_step(step)

    def make_result_table(self, name):
        """Return the path of table ``name`` inside the working folder."""
        return ypath_join(self.garbage_folder, name)

    def add_step(self, step_name, result_table, step_function):
        """Register a step; ``step_function(result_table, step_name, yt_client)`` must fill ``result_table``."""
        self.steps.append({
            'step_name': step_name,
            'result_table': result_table,
            'step_function': step_function,
            'status': 'Ready'
        })

    def run_step(self, step):
        """Execute ``step`` unless its result table already holds data; then mark it 'Done'."""
        self.check_status(step)
        if step['status'] == 'Done':
            logger.info('Table is ready. Step was done before.')
        else:
            logger.info('Step [{}] is started'.format(step['step_name']))
            step['step_function'](step['result_table'], step['step_name'], self.yt_client)
            # Was ``==`` (a no-op comparison), so the status was never updated.
            step['status'] = 'Done'
        logger.info('Step [{}] is Done'.format(step['step_name']))

    def check_status(self, step):
        """Mark ``step`` as 'Done' if its result table exists and is non-empty."""
        if self.yt_client.exists(step['result_table']) and self.yt_client.row_count(step['result_table']) > 0:
            step['status'] = 'Done'

    def title_add_step_name(self, step_name):
        """Return a copy of SPEC whose title has ``step_name`` appended."""
        step_spec = copy.copy(SPEC)
        step_spec['title'] += step_name
        return step_spec

    def step1(self, result_table, step_name, yt_client):
        """Reduce input_yuid with the grepped logs into per-title token counts."""
        step_spec = self.title_add_step_name(step_name)
        spy_tables = yt_client.list(ypath_join(self.project_folder, 'datacloud/grep/spy_log'), absolute=True)
        watch_tables = yt_client.list(ypath_join(self.project_folder, 'datacloud/grep/watch_log_tskv'), absolute=True)
        all_tables = spy_tables + watch_tables
        # input_yuid must be table 0: the reducer reads the retro date from it.
        yt_client.run_reduce(
            ReduceRetroDateAndSplit(self.history_length),
            [self.input_yuid] + all_tables,
            result_table,
            reduce_by=['external_id', 'yuid'],
            spec=step_spec
        )

    def step2(self, result_table, step_name, yt_client):
        """Lemmatize and filter tokens, then attach dictionary indexes.

        ``yql_add_indexes`` writes ``<garbage_folder>/lemmered_filtered_indexed_words``,
        which is this step's ``result_table``. Without that call the step never
        produced its result table and step3 had nothing to read.
        """
        lemmered_words_count, lemmered_dict_filtered = yql_lemmatize(
            self.words_by_yuids,
            self.garbage_folder,
            self.project_folder,
            yt_client,
            self.min_amount,
            self.min_ids
        )
        yql_add_indexes(lemmered_words_count, lemmered_dict_filtered, self.garbage_folder, yt_client)

    def step3(self, result_table, step_name, yt_client):
        """Collapse indexed lemma counts into per-external_id ``indexes``/``amounts`` lists."""
        step_spec = self.title_add_step_name(step_name)
        yt_client.run_sort(self.lemmered_filtered_indexed_words, sort_by=['external_id'], spec=step_spec)
        yt_client.run_reduce(
            make_lists,
            self.lemmered_filtered_indexed_words,
            result_table,
            reduce_by=['external_id'],
            spec=step_spec
        )

    def step4(self, result_table, step_name, yt_client):
        """Join the per-id word lists with the target columns of input_yuid."""
        step_spec = self.title_add_step_name(step_name)
        yt_client.run_sort(self.listed_words, sort_by='external_id', spec=step_spec)
        yt_client.run_reduce(
            JoinTargetReduce(self.target_names),
            [self.listed_words, self.input_yuid],
            result_table,
            reduce_by='external_id',
            spec=step_spec
        )

    def step5(self, result_table, step_name, yt_client):
        """Train selection models and store the useful word indexes per target."""
        dict_size = yt_client.row_count(self.lemmered_dict_filtered)
        get_word_subset_lgb(
            self.listed_words_with_targets,
            self.target_names,
            self.garbage_folder,
            dict_size,
            self.ticket_name,
            yt_client
        )

    def step6(self, result_table, step_name, yt_client):
        """Keep only the selected words and renumber them densely.

        The union of the selected indexes over all targets is enumerated once.
        The resulting ``new_index -> index`` mapping is saved, joined with the
        lemma names, and applied to every row by ``map_only_my_indexes``. Both
        enumerate the same set object, so they agree on the numbering.
        """
        step_spec = self.title_add_step_name(step_name)
        final_word_set = set()
        new2old_indexes_dict = ypath_join(self.garbage_folder, 'new2old_indexes_dict')
        for rec in yt_client.read_table(self.result_indexes):
            final_word_set.update(rec['word_indexes'])
        new2old = [{'new_index': i, 'index': j} for i, j in enumerate(final_word_set)]
        yt_client.write_table(new2old_indexes_dict, new2old)
        yql_add_token_names(new2old_indexes_dict, self.lemmered_dict_filtered, self.indexed_useful_dict, yt_client)
        yt_client.run_map(
            map_only_my_indexes(final_word_set),
            self.listed_words_with_targets,
            result_table,
            spec=step_spec
        )

    def step7(self, result_table, step_name, yt_client):
        """Train the final models on the reduced vocabulary and save their results."""
        learn_final_model(
            self.final_listed_words_with_targets,
            result_table,
            self.garbage_folder,
            self.target_names,
            yt_client.row_count(self.indexed_useful_dict),
            self.ticket_name,
            yt_client
        )


@yt.with_context
class ReduceRetroDateAndSplit():
    """YT reducer turning page titles into per-token counts inside a history window.

    Expected input, reduced by ``external_id, yuid``:

    - table 0 is ``input_yuid`` and provides the retro date in its
      ``timestamp`` column;
    - every other table is a log table with ``timestamp``, ``title``,
      ``external_id`` and ``yuid`` columns.

    Each log row whose timestamp lies strictly inside
    ``(retro_date - history_length, retro_date - TIMESHIFT)`` is processed as
    follows. Its UTF-8 ``title`` is split into words of two or more word
    characters, and one row ``{external_id, yuid, token, amount}`` is yielded
    per distinct word, where ``amount`` is that word's count in the title.
    Titles that are not valid UTF-8 are skipped.

    NOTE(review): this relies on table-0 rows arriving before log rows within
    a key; a log row seen first raises ValueError.

    Args:
        history_length: window length in the units of ``timestamp``
            (default: 180 days expressed in seconds).
    """

    def __init__(self, history_length=180 * 60 * 60 * 24):
        self.history_length = history_length

    def check_retro_date(self, timestamp, retro_date, history_length):
        """Return True iff ``retro_date - history_length < timestamp < retro_date - TIMESHIFT``."""
        before = retro_date - timestamp - TIMESHIFT > 0
        after = retro_date - timestamp < history_length
        return before and after

    def __call__(self, key, recs, context):
        from collections import Counter
        import re
        # Same pattern as the Python 2-only literal ur'(?u)\b\w\w+\b'; the
        # u'' form with escaped backslashes also parses on Python 3.
        word_re = re.compile(u'(?u)\\b\\w\\w+\\b')
        retro_found = False
        retro_timestamp = None
        for rec in recs:
            if context.table_index == 0:
                retro_timestamp = rec['timestamp']
                retro_found = True
                continue
            if not retro_found:
                # A log row arrived before any input_yuid row for this key.
                # ``raise 'string'`` is itself a TypeError, so raise a real exception.
                raise ValueError('Some bugs in the table pathes')
            in_window = self.check_retro_date(rec['timestamp'], retro_timestamp, self.history_length)
            if not (in_window and rec['title'] is not None):
                continue
            try:
                words = word_re.findall(rec['title'].decode('utf-8'))
            except UnicodeDecodeError:
                # Undecodable titles are skipped on purpose (best-effort parsing).
                continue
            # ``token`` (not ``key``) so the reducer key argument is not shadowed.
            for token, amount in Counter(words).items():
                yield {
                    'external_id': rec['external_id'],
                    'yuid': rec['yuid'],
                    'token': token,
                    'amount': amount
                }


class map_only_my_indexes():
    """YT mapper keeping only the word indexes in ``word_set`` and renumbering them.

    Each kept index ``i`` becomes its position in the iteration order of
    ``word_set``. Amounts of dropped indexes are discarded, and all other
    columns of the row are copied unchanged.
    """

    def __init__(self, word_set):
        self.word_set = word_set
        self.new_indexes = dict((old, new) for new, old in enumerate(word_set))

    def __call__(self, rec):
        kept = [
            (self.new_indexes[old_index], amount)
            for old_index, amount in zip(rec['indexes'], rec['amounts'])
            if old_index in self.word_set
        ]
        out = copy.copy(rec)
        out['indexes'] = [new_index for new_index, _ in kept]
        out['amounts'] = [amount for _, amount in kept]
        yield out


def yql_lemmatize(raw_words_count, garbage_folder, project_folder, yt_client, min_amount=100, min_ids=5):
    """Lemmatize tokens, build the lemma dictionary, filter it and number its rows.

    Tables written into ``garbage_folder`` (each truncated first):

    - ``lemmered_words_count``: ``external_id, token_lemmered, amount``, the
      token amounts summed per id and lemma;
    - ``lemmered_dict``: ``token_lemmered, amount, ids``, where ``amount`` is the
      total count and ``ids`` the number of distinct ``external_id`` values;
    - ``lemmered_dict_filtered``: lemmas with ``amount > min_amount`` and
      ``ids > min_ids``, rewritten after a COMMIT with an ``index`` column equal
      to ``TableRecordIndex() - 1``.

    ``ids`` previously used ``count(external_id)``, which counts input rows
    (one per log record and token), not ids. This made ``min_ids`` filter on
    the wrong quantity.

    ``project_folder`` is unused; it is kept for interface compatibility.

    Returns:
        ``(lemmered_words_count, lemmered_dict_filtered)`` table paths.
    """

    # ``%(name)s`` placeholders are filled from ``params`` by execute_yql.
    query = """
        pragma yt.ForceInferSchema;
        PRAGMA yt.PoolTrees = "physical";
        PRAGMA yt.TentativePoolTrees = "cloud";


        $to_words = AugurTokenizer::ToWords("", "lemmer-norm");


        $step0 = (
        SELECT
            external_id,
            token,
            $to_words(token, "ru") as token_lemmered,
            amount

        FROM `%(raw_words_count)s`
        );


        $step1 = (
        SELECT
            external_id,
            token,
            token_lemmered,
            amount
        FROM $step0
        FLATTEN BY token_lemmered
        );


        INSERT INTO `%(lemmered_words_count)s` WITH TRUNCATE
        SELECT
            external_id,
            token_lemmered,
            sum(amount) as amount

        FROM $step1
        GROUP BY external_id, token_lemmered;

        $step2 = (
        SELECT
            token_lemmered,
            sum(amount) as amount,
            count(DISTINCT external_id) as ids

        FROM $step1
        GROUP BY token_lemmered
        );

        INSERT INTO `%(lemmered_dict)s` WITH TRUNCATE
        SELECT
        *
        from $step2;

        insert into `%(lemmered_dict_filtered)s` with truncate
        SELECT
            token_lemmered,
            amount

        from $step2
        WHERE amount > %(min_amount)s and ids > %(min_ids)s;

        COMMIT;

        INSERT INTO `%(lemmered_dict_filtered)s` WITH TRUNCATE
        select
        (TableRecordIndex() - 1) as index,
        token_lemmered,
        amount
        from `%(lemmered_dict_filtered)s`
        ORDER BY index
    """

    lemmered_words_count = ypath_join(garbage_folder, 'lemmered_words_count')
    lemmered_dict = ypath_join(garbage_folder, 'lemmered_dict')
    lemmered_dict_filtered = ypath_join(garbage_folder, 'lemmered_dict_filtered')
    execute_yql(
        query=query,
        params={
            'raw_words_count': raw_words_count,
            'lemmered_words_count': lemmered_words_count,
            'lemmered_dict': lemmered_dict,
            'lemmered_dict_filtered': lemmered_dict_filtered,
            'min_amount': min_amount,
            'min_ids': min_ids
        },
        yt_client=yt_client
    )
    return lemmered_words_count, lemmered_dict_filtered


def yql_add_indexes(lemmered_words_count, lemmered_dict_filtered, garbage_folder, yt_client):
    """Attach dictionary indexes to the per-id lemma counts.

    Inner-joins ``lemmered_words_count`` with ``lemmered_dict_filtered`` on
    ``token_lemmered``, so lemmas absent from the filtered dictionary are
    dropped. Writes ``external_id, token_lemmered, amount, index`` to
    ``<garbage_folder>/lemmered_filtered_indexed_words``, truncating it first.

    Returns:
        Path of the written table.
    """
    query = """
    pragma yt.ForceInferSchema;
    PRAGMA yt.PoolTrees = "physical";
    PRAGMA yt.TentativePoolTrees = "cloud";

    INSERT INTO `%(output_table)s` WITH TRUNCATE
    SELECT
        a.external_id as external_id,
        a.token_lemmered as token_lemmered,
        a.amount as amount,
        b.index as index

    FROM `%(lemmered_words_count)s` as a
    inner join `%(lemmered_dict_filtered)s` as b
    on a.token_lemmered == b.token_lemmered
    """
    output_table = ypath_join(garbage_folder, 'lemmered_filtered_indexed_words')
    query_params = dict(
        output_table=output_table,
        lemmered_words_count=lemmered_words_count,
        lemmered_dict_filtered=lemmered_dict_filtered,
    )
    execute_yql(query=query, params=query_params, yt_client=yt_client)
    return output_table


def yql_add_token_names(new2old_indexes_dict, lemmered_dict_filtered, indexed_useful_dict, yt_client):
    """Write the dictionary of selected words with their new and old indexes.

    Joins the ``new_index -> index`` mapping with the filtered lemma dictionary
    on ``index``. Writes ``new_index, old_index, token, amount`` to
    ``indexed_useful_dict``, truncating it first. The query is pinned to the
    ``YT_PROXY`` cluster through a ``use`` statement.
    """
    query_template = """
    use {yt_cluster};
    pragma yt.ForceInferSchema;

    INSERT INTO `%(indexed_useful_dict)s` WITH TRUNCATE
    SELECT
        a.new_index as new_index,
        a.index as old_index,
        b.token_lemmered as token,
        b.amount as amount

    FROM `%(new2old_indexes_dict)s` as a
    JOIN `%(lemmered_dict_filtered)s` as b
    ON a.index == b.index
    """
    query = query_template.format(yt_cluster=YT_PROXY)
    table_params = {
        'indexed_useful_dict': indexed_useful_dict,
        'new2old_indexes_dict': new2old_indexes_dict,
        'lemmered_dict_filtered': lemmered_dict_filtered,
    }
    execute_yql(query=query, params=table_params, yt_client=yt_client)


def make_lists(key, recs):
    """YT reducer: gather one id's ``index``/``amount`` rows into two parallel lists.

    Yields a single row ``{external_id, amounts, indexes}`` whose lists keep
    the order in which the records arrived.
    """
    pairs = [(rec['amount'], rec['index']) for rec in recs]
    yield {
        'external_id': key['external_id'],
        'amounts': [amount for amount, _ in pairs],
        'indexes': [index for _, index in pairs]
    }


@yt.with_context
class JoinTargetReduce():
    """YT reducer joining word lists (table 0) with target columns (other tables).

    For each key, the ``amounts``, ``indexes`` and ``external_id`` columns come
    from table 0, and the columns named in ``targets`` come from the other
    tables; later rows overwrite earlier ones. A row is yielded only when both
    sides are present.
    """

    def __init__(self, targets):
        self.targets = targets

    def __call__(self, key, recs, context):
        joined = dict()
        has_features = has_targets = False
        for rec in recs:
            if context.table_index != 0:
                has_targets = True
                joined.update((name, rec[name]) for name in self.targets)
                continue
            has_features = True
            for column in ('amounts', 'indexes', 'external_id'):
                joined[column] = rec[column]

        if has_features and has_targets:
            yield joined


def parse_row(row, length):
    """Build a ``1 x length`` float32 COO matrix from a row's ``amounts``/``indexes`` lists.

    Repeated indexes are summed when the matrix is densified or converted.
    """
    from scipy import sparse
    values = row['amounts']
    columns = row['indexes']
    row_ids = np.zeros(len(columns))
    return sparse.coo_matrix((values, (row_ids, columns)), shape=(1, length), dtype=np.float32)


def only_train_targets(rec, target_names):
    """Return True when every target of ``rec`` is labelled 0 or 1.

    Evaluation stops at the first target that fails, so later target columns
    are not read.
    """
    return all(rec[name] in [0, 1] for name in target_names)


def _positive_importance_indexes(clf):
    """Return the feature indexes (as plain ints) whose importance in ``clf`` is positive.

    Unlike ``argsort(imp)[-k:]`` with ``k = sum(imp > 0)``, this yields an
    empty set when ``k == 0``; the slice ``[-0:]`` would select every feature.
    """
    importance = np.asarray(clf.feature_importance())
    return set(int(ix) for ix in np.flatnonzero(importance > 0))


def get_word_subset_lgb(
        input_table,
        target_names,
        garbage_folder,
        dict_size,
        ticket_name,
        yt_client,
        train_folds=True):
    """Select useful word indexes per target from model feature importances.

    Reads ``input_table`` rows (``indexes``/``amounts`` lists plus target
    columns) and keeps rows whose targets are all 0 or 1. Of those, it keeps
    every row with ``target_names[0] == 1`` and an unseeded random ~50% of the
    rest. For each target, models from the project ``model`` module are
    trained (presumably LightGBM, given the name — TODO confirm):

    - ``train_folds=True``: the indexes with positive importance in *every*
      fold model (intersection);
    - ``train_folds=False``: the indexes with positive importance in a single model.

    One row ``{target_name, word_indexes, total_words}`` per target is appended
    to ``<garbage_folder>/result_indexes``.

    Raises:
        ValueError: if no training row was loaded.
    """
    from scipy import sparse
    # Alias the fold trainer: importing it as ``train_folds`` shadowed the
    # ``train_folds`` flag, so the fold branch always ran.
    from model import train, train_folds as train_folds_fn

    targets = {target_name: [] for target_name in target_names}
    chunks = [[]]  # per-row sparse matrices, grouped into chunks of 1e4 rows
    n_rows = 0
    prev_n_rows = 0
    t_start = time.time()
    for row in yt_client.read_table(input_table):
        if only_train_targets(row, target_names) and (row[target_names[0]] == 1 or np.random.randint(2) == 0):
            chunks[-1].append(parse_row(row, dict_size))
            for target_name in target_names:
                targets[target_name].append(row[target_name])
            n_rows += 1

        if n_rows % 1e4 == 0 and n_rows != prev_n_rows:
            logger.info('Loaded {} rows, {} s passed'.format(n_rows, time.time() - t_start))
            chunks.append([])
            prev_n_rows = n_rows

    logger.info('{} rows loaded!'.format(n_rows))
    if n_rows == 0:
        raise ValueError('No training rows found in {}'.format(input_table))
    # Skip empty chunks: the last one is empty whenever n_rows is a multiple of
    # 1e4, and sparse.vstack rejects an empty list.
    X = sparse.vstack([sparse.vstack(chunk) for chunk in chunks if chunk]).tocsr()

    recs = []
    for target_name in target_names:
        y_all = np.array(targets[target_name])
        ix_train = y_all >= 0
        X_train = X[np.where(ix_train)[0], :]
        y_train = y_all[ix_train]
        logger.info('{}'.format(np.unique(y_train)))
        model_name = ticket_name + '_' + target_name
        if train_folds:
            clfs, _ = train_folds_fn(X_train, y_train, [], model_name)
            # None (not an empty container) marks "no fold seen yet", so an empty
            # intersection stays empty instead of being reset by the next fold.
            word_set = None
            for clf in clfs:
                clf_words = _positive_importance_indexes(clf)
                word_set = clf_words if word_set is None else word_set & clf_words
            word_set = word_set or set()
            logger.info('For target [{}], word set has length: {}'.format(target_name, len(word_set)))
        else:
            clf, _ = train(X_train, y_train, [], model_name)
            word_set = _positive_importance_indexes(clf)
        word_indexes = sorted(word_set)
        recs.append({
            'target_name': target_name,
            'word_indexes': word_indexes,
            'total_words': len(word_indexes)
        })

    # ONLY for TCS:
    # y1 = np.array(targets['flg1'])
    # y2 = np.array(targets['flg2'])
    # ix = y1 == 1
    # clf, _ = train(X[ix, :], y2[ix], [], ticket_name + '_1to2')
    # final_words['1to2'] = list(set(np.argsort(clf.feature_importance())[-1*sum(clf.feature_importance() >0):]))
    # recs.append({
    #         'target_name': '1to2',
    #         'word_indexes': final_words['1to2'],
    #         'total_words': len(final_words['1to2'])
    #     })
    # TablePath (not FilePath) is the path type for tables; append mode is kept.
    yt_client.write_table(yt.TablePath(ypath_join(garbage_folder, 'result_indexes'), append=True), recs)


def learn_final_model(input_table, result_table, garbage_folder, target_names, final_dict_len, ticket_name, yt_client):
    """Train one final model per target on the reduced vocabulary.

    Loads every ``input_table`` row whose targets are all 0 or 1 as a
    ``1 x final_dict_len`` sparse row. For each target, it calls the project
    ``model.train`` with model name ``<ticket_name>_<target>``. Writes:

    - ``<garbage_folder>/feature_importance``: ``{target_name, feature_importance}``;
    - ``result_table``: the ``results`` dict returned by ``train``, plus ``target``.

    Raises:
        ValueError: if no training row was loaded.
    """
    from scipy import sparse
    from model import train

    targets = {target_name: [] for target_name in target_names}
    chunks = [[]]  # per-row sparse matrices, grouped into chunks of 1e4 rows
    n_rows = 0
    prev_n_rows = 0
    t_start = time.time()
    for row in yt_client.read_table(input_table):
        if only_train_targets(row, target_names):
            chunks[-1].append(parse_row(row, final_dict_len))
            for target_name in target_names:
                targets[target_name].append(row[target_name])
            n_rows += 1

        if n_rows % 1e4 == 0 and n_rows != prev_n_rows:
            logger.info('Loaded {} rows, elapsed time: {} seconds'.format(n_rows, time.time() - t_start))
            chunks.append([])
            prev_n_rows = n_rows

    if n_rows == 0:
        raise ValueError('No training rows found in {}'.format(input_table))
    # Skip empty chunks: the last one is empty whenever n_rows is a multiple of
    # 1e4, and sparse.vstack rejects an empty list.
    X = sparse.vstack([sparse.vstack(chunk) for chunk in chunks if chunk]).tocsr()

    recs = []
    feature_importance = []
    for target_name in target_names:
        y_all = np.array(targets[target_name])
        ix_train = y_all >= 0
        X_train = X[np.where(ix_train)[0], :]
        y_train = y_all[ix_train]
        logger.info('{}'.format(np.unique(y_train)))
        model_name = ticket_name + '_' + target_name
        model, results = train(X_train, y_train, [], model_name)
        feature_importance.append({
            'target_name': target_name,
            # Convert the numpy array to a plain list so it can be written to YT.
            'feature_importance': np.asarray(model.feature_importance()).tolist()
        })
        results['target'] = target_name
        recs.append(results)

    # y1 = np.array(targets['flg1'])
    # y2 = np.array(targets['flg2'])
    # ix = y1 == 1
    # _, results = train(X[ix, :], y2[ix], [], ticket_name + '_1to2')
    # results['target'] = '1to2'
    # recs.append(results)
    feature_importance_table = ypath_join(garbage_folder, 'feature_importance')
    yt_client.write_table(feature_importance_table, feature_importance)
    yt_client.write_table(result_table, recs)


if __name__ == '__main__':
    # Run the whole pipeline for one ticket with a 180-day history window and
    # both lemma-filter thresholds set to 0.
    pipeline = DoSemanticSearch(
        '//projects/scoring/tcs/XPROD-1251',
        ['target'],
        yt_utils.get_yt_client(YT_PROXY),
        history_length=180,
        suffix='180',
        min_amount=0,
        min_ids=0,
    )
    pipeline()
