# -*- coding: utf-8 -*-
import sys
from library.python.nyt import client as nyt_client
# Initialise the nyt client with the command-line arguments, each encoded to bytes.
# NOTE(review): this call sits before the remaining imports - presumably on
# purpose; confirm before reordering the import block.
nyt_client.initialize(list(map(lambda it: it.encode(), sys.argv)))
# Earlier variant that passed the arguments without encoding them:
# nyt_client.initialize(sys.argv)
import os
import re
from datetime import datetime, timedelta
from enum import Enum, unique
import numpy as np
import yt.wrapper as yt_wrapper
from datacloud.dev_utils.time.patterns import RE_DAILY_LOG_FORMAT
from datacloud.dev_utils.yt import yt_utils
import datacloud.dev_utils.data.data_utils as du
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.features.cluster.path_config import PathConfig
from datacloud.features.cluster import extract_url
from datacloud.features.cluster import fast_clust


# Module-level logger shared by every pipeline step in this file.
logger = get_basic_logger(__name__)


@unique
class ClustFeaturesBuildSteps(Enum):
    """Identifiers of the stages of the cluster-features build.

    Every member's value is equal to its own name, and ``@unique`` rejects
    members that share a value. ``build_retro_vectors`` checks membership of
    these identifiers in ``steps_to_run`` to decide which ``step_*`` function
    to call.
    """
    step_1_daily_hostnames_extract = 'step_1_daily_hostnames_extract'
    step_2_bow_reducer = 'step_2_bow_reducer'
    step_3_build_user2host = 'step_3_build_user2host'
    step_4_build_user2clust = 'step_4_build_user2clust'


# Default value of ``steps_to_run`` in ``build_retro_vectors``: all four steps.
# Listed explicitly (rather than derived from the enum) so that adding a new
# enum member does not silently change the default.
CLUST_DEFAULT_FEATURES_BUILD_STEPS = (
    ClustFeaturesBuildSteps.step_1_daily_hostnames_extract,
    ClustFeaturesBuildSteps.step_2_bow_reducer,
    ClustFeaturesBuildSteps.step_3_build_user2host,
    ClustFeaturesBuildSteps.step_4_build_user2clust,
)


def remove_old_tables(yt_client, folder, tables_to_keep, pattern=RE_DAILY_LOG_FORMAT):
    """Remove all but the last ``tables_to_keep`` matching tables of ``folder``.

    Nodes of ``folder`` whose base name matches ``pattern`` are sorted by
    their absolute path; every entry except the last ``tables_to_keep`` ones
    (in that sorted order) is removed. With date-like names the survivors are
    presumably the most recent tables.

    Args:
        yt_client: client exposing ``list(folder, absolute=True)`` and
            ``remove(path)``.
        folder: directory whose tables are cleaned up.
        tables_to_keep: number of trailing tables to keep. ``0`` removes
            every matching table; a value larger than the number of matching
            tables removes nothing.
        pattern: regular expression applied with ``re.match`` to the last
            path component of every listed node.

    Raises:
        AssertionError: if ``tables_to_keep`` is negative.
    """
    assert tables_to_keep >= 0, 'Tables to keep must be nonnegative. Passed {}'.format(tables_to_keep)
    tables = sorted(
        table for table in yt_client.list(folder, absolute=True)
        if re.match(pattern, table.split('/')[-1])
    )
    # Count from the front: ``tables[:-0]`` is an empty slice, so a negative
    # end index would keep everything when ``tables_to_keep`` is 0.
    remove_count = max(len(tables) - tables_to_keep, 0)
    for table in tables[:remove_count]:
        logger.info('[REMOVE OLD TABLES] Remove old table: {}'.format(table))
        yt_client.remove(table)


@yt_wrapper.with_context
class AppendCriptaVectorReducer(object):
    """Reducer that copies ``vector_b``/``vector_m`` onto rows of later tables.

    Rows read while ``context.table_index == 0`` only supply the two vectors.
    Rows from any other table get those vectors attached and are emitted.
    In ``build_cluster_centers_tables`` table 0 is the host-vectors table and
    the reduce key is ``host``.
    """

    def __call__(self, key, recs, context):
        # NOTE(review): relies on table-0 rows arriving before the other rows
        # of the same key - confirm against the reduce input ordering.
        host_vector_b = None
        for row in recs:
            if context.table_index == 0:
                host_vector_b = row['vector_b']
                host_vector_m = row['vector_m']
                continue
            if host_vector_b is None:
                # No vectors known for this key: drop the rest of the group.
                break
            row['vector_b'] = host_vector_b
            row['vector_m'] = host_vector_m
            yield row


def cluster_center_reduce(key, recs):
    """Sum the host vectors of one category into a single output row.

    For every input row ``vector_b`` and ``vector_m`` are decoded with
    ``du.array_fromstring`` and concatenated; the concatenated vectors of the
    whole group are added element-wise. Note that the result is a sum, not a
    mean.

    Args:
        key: reduce key; ``key['cat']`` is copied to the output.
        recs: iterable of rows carrying ``vector_b``, ``vector_m`` and
            ``cat_name``.

    Yields:
        One dict with ``cat``, ``cat_name`` (taken from the first row of the
        group) and ``vector`` (the sum encoded with ``du.array_tostring``).
        Nothing is yielded for an empty group.
    """
    def extract_vector(rec):
        vector_b = du.array_fromstring(rec['vector_b'])
        vector_m = du.array_fromstring(rec['vector_m'])
        return np.concatenate((vector_b, vector_m))

    cat = key['cat']
    # The builtin ``next`` works for Python 2 and Python 3 iterators alike,
    # whereas the ``.next()`` method exists only on Python 2 iterators.
    recs = iter(recs)
    first_rec = next(recs, None)
    if first_rec is None:
        return
    out_vector = extract_vector(first_rec)
    cat_name = first_rec['cat_name']
    for rec in recs:
        out_vector += extract_vector(rec)
    yield {
        'cat': cat,
        'cat_name': cat_name,
        'vector': du.array_tostring(out_vector)
    }


class AppendClusterCentersMapper(object):
    """Mapper that replaces a row's features with cosines to cluster centers.

    ``centers`` and ``norms`` are stored as given. In this file they come
    from ``get_cluster_centers_and_norms``, which returns ``np.matrix``
    objects.
    """

    def __init__(self, centers, norms):
        self.norms = norms
        self.centers = centers

    def __call__(self, rec):
        """Yield ``{'key', 'features'}`` with the encoded cosine values."""
        features = du.array_fromstring(rec['features'])
        projections = features.dot(self.centers)
        scale = self.norms * np.linalg.norm(features)
        similarities = projections / scale
        # presumably the result is still 2-D (np.matrix keeps two dimensions),
        # hence the first row is taken after ``tolist()`` - TODO confirm.
        first_row = np.squeeze(similarities).tolist()[0]
        yield {
            'key': rec['key'],
            'features': du.array_tostring(first_row)
        }


class HostnameBowReducer(object):
    """Reducer that sums ``counter`` per ``host`` within one external id."""

    def __init__(self, EXTERNAL_ID_KEY):
        # Name of the key column that holds the external id.
        self.EXTERNAL_ID_KEY = EXTERNAL_ID_KEY

    def __call__(self, key, recs):
        """Yield one ``{id, host, counter}`` row per distinct host."""
        id_column = self.EXTERNAL_ID_KEY
        external_id = key[id_column]
        totals = {}
        for rec in recs:
            host = rec['host']
            if host not in totals:
                totals[host] = rec['counter']
            else:
                totals[host] += rec['counter']
        for host, total in totals.items():
            yield {
                id_column: external_id,
                'host': host,
                'counter': total,
            }


class NormS2vMapper(object):
    """Mapper that divides a feature vector by its norm and renames the key.

    The value of column ``input_key`` is emitted under ``output_key``; the
    vector in ``features_col`` is decoded, divided by ``np.linalg.norm`` of
    itself and encoded back into the same column.
    """

    def __init__(self, input_key, output_key, features_col='features'):
        self.features_col = features_col
        self.output_key = output_key
        self.input_key = input_key

    def __call__(self, rec):
        """Yield a single row with the renamed key and the scaled vector."""
        column = self.features_col
        vector = du.array_fromstring(rec[column])
        scaled = vector / np.linalg.norm(vector)
        yield {
            self.output_key: rec[self.input_key],
            column: du.array_tostring(scaled)
        }


def build_cluster_centers_tables(path_config, yt_client):
    """Build ``path_config.cluster_centers_table`` from the catalog hosts.

    Steps:
      1. sort the catalog hosts table by ``host`` (in place);
      2. reduce it together with the host-vectors table by ``host`` so that
         catalog rows receive ``vector_b``/``vector_m``;
      3. map-reduce the joined rows by ``cat`` into one summed vector each.

    NOTE(review): only the catalog table is sorted here; the host-vectors
    table is presumably already sorted by ``host`` - confirm.
    """
    tag = path_config.tag
    catalog_hosts = path_config.yandex_catalog_hosts

    yt_client.run_sort(
        catalog_hosts,
        sort_by=['host'],
        spec={'title': '[{}] 1 sort yandex catalog hosts'.format(tag)},
    )

    with yt_client.TempTable(path_config.tmp_dir, 'catalog_host_vector_table') as joined_table:
        # The vectors table goes first: the reducer treats table index 0 as
        # the source of the vectors.
        join_inputs = [path_config.external_crypta_host_vectors_table, catalog_hosts]
        yt_client.run_reduce(
            AppendCriptaVectorReducer(),
            join_inputs,
            joined_table,
            reduce_by=['host'],
            spec={'title': '[{}] 2 append crypta vector reducer'.format(tag)},
        )
        yt_client.run_map_reduce(
            None,
            cluster_center_reduce,
            joined_table,
            path_config.cluster_centers_table,
            reduce_by=['cat'],
            spec={'title': '[{}] 3-4 clust vector center reduce [binary]'.format(tag)},
        )


def get_input_extracted_url_tables(path_config, yt_client):
    """Collect the existing daily extracted-url tables of the look-back window.

    Walks ``path_config.days_to_take`` days backwards starting at
    ``path_config.date`` (inclusive) and builds the path
    ``<extracted_urls_dir>/<YYYY-MM-DD>`` for each day.

    Args:
        path_config: provides ``date`` (``'%Y-%m-%d'`` string),
            ``days_to_take`` and ``extracted_urls_dir``.
        yt_client: client exposing ``exists(path)``.

    Returns:
        List of ``yt_wrapper.TablePath`` for the days whose table exists,
        ordered from the most recent day backwards. Missing days are logged
        and skipped.

    Raises:
        ValueError: if ``path_config.date`` does not match ``'%Y-%m-%d'``.
    """
    date_format = '%Y-%m-%d'
    last_day = datetime.strptime(path_config.date, date_format)
    tables_to_process = []
    for day_diff in range(path_config.days_to_take):
        target_date = (last_day - timedelta(days=day_diff)).strftime(date_format)
        table = yt_wrapper.TablePath(yt_wrapper.ypath_join(path_config.extracted_urls_dir, target_date))
        if yt_client.exists(table):
            logger.info(' add table: {}'.format(table))
            tables_to_process.append(table)
        else:
            # ``warning`` is the supported spelling (and the one used by the
            # rest of this file); ``warn`` is a deprecated alias.
            logger.warning(' table {} not exists'.format(table))
    return tables_to_process


def match_bow_to_cid(input_tables, path_config, yt_client):
    """Run ``fast_clust.fast_yuid_to_cid`` over ``input_tables`` and sort the result.

    A fresh copy of ``path_config.external_cid2yuid_table`` is placed at
    ``path_config.all_users`` (an existing table there is removed first), the
    native routine writes ``path_config.cid_bow_table``, and that table is
    then sorted by ``path_config.EXTERNAL_ID_KEY``.

    Raises:
        AssertionError: if no YT token is found in ``yt_wrapper.config`` or in
            the ``YT_TOKEN`` environment variable.
    """
    yt_token = yt_wrapper.config['token'] or os.environ.get('YT_TOKEN')
    assert yt_token, '[MATCH BOW TO CID] No YT_TOKEN provided'

    users_table = path_config.all_users
    if yt_client.exists(users_table):
        yt_client.remove(users_table)
    yt_client.copy(path_config.external_cid2yuid_table, users_table)

    logger.info('Start fast yuid_to_cid')
    fast_clust.fast_yuid_to_cid(
        yt_token,
        yt_client.config['proxy']['url'],
        str(users_table),
        [str(table) for table in input_tables],
        str(path_config.cid_bow_table),
    )
    logger.info('Done fast yuid_to_cid')

    yt_client.run_sort(
        path_config.cid_bow_table,
        sort_by=path_config.EXTERNAL_ID_KEY,
        spec=dict(
            title='[{}] sort bow matched by cid'.format(path_config.tag),
            **path_config.cloud_nodes_spec
        ),
    )


def bow_reducer(path_config, yt_client):
    """Build ``path_config.external_id2bow`` sorted by ``host``.

    In retro mode the input is ``path_config.current_extracted_urls_table``.
    Otherwise the daily extracted-url tables are first matched through
    ``match_bow_to_cid`` and the input is ``path_config.cid_bow_table``.
    The input column ``EXTERNAL_ID_KEY`` is renamed to ``key`` on read.

    After the run ``cid_bow_table`` is removed if it exists, and outside
    retro mode ``all_users`` as well.

    Raises:
        AssertionError: if no YT token is available, or if none of the input
            tables exists.
    """
    if path_config.is_retro:
        candidate_tables = [
            path_config.current_extracted_urls_table
        ]
    else:
        input_tables = get_input_extracted_url_tables(path_config, yt_client)
        match_bow_to_cid(input_tables, path_config, yt_client)
        candidate_tables = [path_config.cid_bow_table]

    yt_token = yt_wrapper.config['token'] or os.environ.get('YT_TOKEN')
    assert yt_token, '[CLUST FEATURES] Bow Reducer No YT_TOKEN provided'

    rename_columns = '<rename_columns={{{ext_id}=key}}>'.format(
        ext_id=path_config.EXTERNAL_ID_KEY)

    # Keep only the tables that are really there, so that a table reported as
    # missing is not handed to the native reducer afterwards.
    existing_tables = []
    for table in candidate_tables:
        if yt_utils.check_table_exists(table, yt_client):
            existing_tables.append(table)
        else:
            logger.info('!!> table {} not exists'.format(table))

    if not existing_tables:
        # Same exception type as the former ``assert``, but not stripped
        # when Python runs with ``-O``.
        raise AssertionError('No input table to bow_reducer')

    fast_clust.fast_hostname_bow(
        yt_token,
        yt_client.config['proxy']['url'],
        [rename_columns + str(it) for it in existing_tables],
        str(path_config.external_id2bow)
    )

    yt_client.run_sort(
        path_config.external_id2bow,
        sort_by=['host'],
        spec=dict(
            title='[{}] 6 sort after HostnameBowreducer'.format(path_config.tag),
            **path_config.cloud_nodes_spec
        )
    )

    # Drop the intermediate tables.
    if yt_client.exists(path_config.cid_bow_table):
        yt_client.remove(path_config.cid_bow_table)
    if not path_config.is_retro and yt_client.exists(path_config.all_users):
        yt_client.remove(path_config.all_users)


def build_user2host_features_table(path_config, yt_client):
    """Build ``path_config.res_user2host_features`` sorted by ``key``.

    Two native routines run back to back through a temporary table:
    ``fast_append_host_vector_reducer`` (host vectors + ``external_id2bow``
    into the temporary table) and ``fast_user_vector_reducer`` (temporary
    table into ``res_user2host_features``). Afterwards the result is sorted
    by ``key`` and ``path_config.external_id2bow`` is removed.

    Raises:
        AssertionError: if no YT token is found in ``yt_wrapper.config`` or in
            the ``YT_TOKEN`` environment variable.
    """
    yt_token = yt_wrapper.config['token'] or os.environ.get('YT_TOKEN')
    # Name this step in the message so that a missing token is attributed to
    # the right stage of the pipeline.
    assert yt_token, '[CLUST FEATURES] Build user2host No YT_TOKEN provided'
    proxy_url = yt_client.config['proxy']['url']
    with yt_client.TempTable(path_config.tmp_dir, 'scoring_external_id2bow2') as tmp_table:
        logger.info('Start fast_append_host_vector_reducer')
        fast_clust.fast_append_host_vector_reducer(
            yt_token,
            proxy_url,
            str(path_config.external_crypta_host_vectors_table),
            str(path_config.external_id2bow),
            # Passed as ``str`` like every other path handed to ``fast_clust``.
            str(tmp_table))
        logger.info('Done fast_append_host_vector_reducer')
        logger.info('Start fast_user_vector_reducer')
        fast_clust.fast_user_vector_reducer(
            yt_token,
            proxy_url,
            str(tmp_table),
            str(path_config.res_user2host_features)
        )
    logger.info('Done fast_user_vector_reducer')
    yt_client.run_sort(
        path_config.res_user2host_features,
        sort_by='key',
        spec=dict(
            title='[{}] 10 Sort User2Host table'.format(path_config.tag),
            **path_config.cloud_nodes_spec
        )
    )
    # The bag-of-words table is only an intermediate of this step.
    yt_client.remove(path_config.external_id2bow)


def build_user2clust_features_table(path_config, yt_client):
    """Build ``path_config.res_user2clust_features`` from the user2host table.

    Inside one transaction: map every user2host row to its cosines against
    the cluster centers, sort the result by ``path_config.EXTERNAL_ID_KEY``
    into the target table, then merge the target table in place with
    ``combine_chunks``.
    """
    with yt_client.Transaction():
        with yt_client.TempTable() as tmp_table:
            # Centers and norms are read on the client side and handed to the
            # mapper through its constructor.
            centers, norms = get_cluster_centers_and_norms(path_config, yt_client)
            # Mapper output goes to the temporary table with an explicit
            # two-column string schema.
            tmp_clust_features_table = yt_wrapper.TablePath(
                tmp_table,
                schema=[
                    {'name': 'key', 'type': 'string'},
                    {'name': 'features', 'type': 'string'}
                ]
            )
            yt_client.run_map(
                AppendClusterCentersMapper(centers, norms),
                # Only the two columns the mapper reads are requested.
                yt_wrapper.TablePath(path_config.res_user2host_features,
                                     columns=['features', 'key']),
                tmp_clust_features_table,
                spec=dict(
                    title='[{}] 12 Build User2Clust'.format(path_config.tag),
                    tmpfs_path='.',
                    copy_files=True,
                    max_failed_job_count=50,
                    mapper={
                        # presumably milliseconds, i.e. one hour - TODO confirm
                        'job_time_limit': 60 * 60 * 1000,
                    },
                    **path_config.cloud_nodes_spec
                )
            )
            # The ``rename_columns`` path attribute renames ``key`` to the
            # external id column while the temporary table is read.
            yt_client.run_sort(
                '<rename_columns={{key={ext_id}}}>'.format(ext_id=path_config.EXTERNAL_ID_KEY) + str(tmp_clust_features_table),
                path_config.res_user2clust_features,
                sort_by=[path_config.EXTERNAL_ID_KEY],
                spec=dict(
                    title='[{}] 13 Sort user2clust features'.format(path_config.tag),
                    **path_config.cloud_nodes_spec
                )
            )
            # Source and destination are the same table: in-place merge.
            yt_client.run_merge(
                path_config.res_user2clust_features,
                path_config.res_user2clust_features,
                spec=dict(
                    title='[{}] 14 Merge chunks for User2Clust'.format(path_config.tag),
                    combine_chunks=True,
                    **path_config.cloud_nodes_spec
                )
            )


def build_normed_s2v_features_table(path_config, yt_client):
    """Build ``path_config.res_user2normed_s2v_features`` inside one transaction.

    Maps ``res_user2host_features`` through ``NormS2vMapper`` (``key`` becomes
    ``path_config.EXTERNAL_ID_KEY``, the vector is divided by its norm) and
    sorts the output by ``path_config.EXTERNAL_ID_KEY``.
    """
    with yt_client.Transaction():
        normalizer = NormS2vMapper(input_key='key', output_key=path_config.EXTERNAL_ID_KEY)
        yt_client.run_map(
            normalizer,
            path_config.res_user2host_features,
            path_config.res_user2normed_s2v_features,
            spec=dict(
                title='[{}] 15 Build Normed s2v features'.format(path_config.tag),
                **path_config.cloud_nodes_spec
            ),
        )
        yt_client.run_sort(
            path_config.res_user2normed_s2v_features,
            sort_by=path_config.EXTERNAL_ID_KEY,
            spec=dict(
                title='[{}] 16 Sort Normed s2v features'.format(path_config.tag),
                **path_config.cloud_nodes_spec
            ),
        )


def get_cluster_centers_and_norms(path_config, yt_client):
    """Load the cluster centers and their norms as ``np.matrix`` objects.

    Every row of ``path_config.cluster_centers_table`` supplies a
    space-separated ``vector`` plus ``cat`` and ``cat_name``. Each center is
    placed at position ``path_config.cat2index[cat]``.

    Returns:
        ``(centers, norms)``: ``centers`` is the transposed matrix of the
        ordered centers (one column per category), ``norms`` is a matrix
        built from the list of their ``np.linalg.norm`` values in the same
        order.
    """
    vector_by_cat = {}
    for row in yt_client.read_table(path_config.cluster_centers_table):
        center = list(np.fromstring(row['vector'], sep=' '))
        vector_by_cat[(row['cat'], row['cat_name'])] = (center, np.linalg.norm(center))

    index_of = path_config.cat2index
    slots = len(index_of)
    # Positions of categories without a row in the table stay ``None``.
    ordered_centers = [None] * slots
    ordered_norms = [None] * slots
    for (cat, _cat_name), (center, norm) in vector_by_cat.items():
        position = index_of[cat]
        ordered_centers[position] = center
        ordered_norms[position] = norm

    return np.matrix(ordered_centers).T, np.matrix(ordered_norms)


def init_config(date, **kwargs):
    """Create the default ``PathConfig`` for ``date`` and a YT client.

    ``kwargs`` are forwarded to ``PathConfig``; ``days_to_take`` is fixed
    to 175.

    Returns:
        ``(yt_client, path_config)``.
    """
    config = PathConfig(date=date, days_to_take=175, **kwargs)
    return yt_utils.get_yt_client(), config


def step_1_daily_hostnames_extract(yt_client, date, path_config=None, **kwargs):
    """Step 1: run ``extract_url.daily_hostnames_extract`` for ``date``.

    When ``path_config`` is not given it is created with
    ``init_config(date, **kwargs)``; the client created there is discarded
    and the passed ``yt_client`` is used.
    """
    logger.info('Start step 1 daily_hostnames extract')
    config = path_config
    if config is None:
        config = init_config(date, **kwargs)[1]
    if config.use_cloud_nodes:
        logger.warning('Attention! Using cloud nodes!')
    extract_url.daily_hostnames_extract(config, yt_client, date)
    logger.info('End step 1 daily_hostnames extract')


def step_2_bow_reducer(yt_client, date, path_config=None, **kwargs):
    """Step 2: create the working folders and run ``bow_reducer``.

    When ``path_config`` is not given it is created with
    ``init_config(date, **kwargs)``; the client created there is discarded
    and the passed ``yt_client`` is used.
    """
    logger.info('Start step 2 bow_reducer')
    config = path_config
    if config is None:
        config = init_config(date, **kwargs)[1]
    if config.use_cloud_nodes:
        logger.warning('Attention! Using cloud nodes!')
    working_dirs = [
        config.data_dir,
        config.tmp_dir,
        config.user2host_dir,
        config.user2clust_dir,
        config.user2normed_s2v_weekly_dir,
        config.ready_dir,
    ]
    yt_utils.create_folders(working_dirs, yt_client)
    bow_reducer(config, yt_client)
    logger.info('End step 2 bow_reducer')


def step_3_build_user2host(yt_client, date, path_config=None, **kwargs):
    """Step 3: build the user2host table, then the normed s2v table.

    When ``path_config`` is not given it is created with
    ``init_config(date, **kwargs)``; the client created there is discarded
    and the passed ``yt_client`` is used.
    """
    logger.info('Start step 3 build user2host')
    config = path_config
    if config is None:
        config = init_config(date, **kwargs)[1]
    if config.use_cloud_nodes:
        logger.warning('Attention! Using cloud nodes!')
    # The normed table is derived from the user2host table, so order matters.
    for build_table in (build_user2host_features_table, build_normed_s2v_features_table):
        build_table(config, yt_client)
    logger.info('End step 3 build user2 host')


def step_4_build_user2clust(yt_client, date, path_config=None, **kwargs):
    """Step 4: build the user2clust table and collect garbage.

    When ``path_config`` is not given it is created with
    ``init_config(date, **kwargs)``; the client created there is discarded
    and the passed ``yt_client`` is used.
    """
    logger.info('Start step 4 build user2clust')
    config = path_config
    if config is None:
        config = init_config(date, **kwargs)[1]
    if config.use_cloud_nodes:
        logger.warning('Attention! Using cloud nodes!')
    build_user2clust_features_table(config, yt_client)
    config.collect_garbage(yt_client)
    logger.info('End step 4 build user2clust')


def build_retro_vectors(date, root, yt_client, path_config=None,
                        steps_to_run=CLUST_DEFAULT_FEATURES_BUILD_STEPS):
    """Run the selected build steps in retro mode, then collect garbage.

    Args:
        date: date passed to ``PathConfig`` and to every step.
        root: root passed to ``PathConfig`` when ``path_config`` is not given.
        yt_client: client handed to every step.
        path_config: ready configuration; when falsy, a ``PathConfig`` with
            ``days_to_take=175`` and ``is_retro=True`` is created.
        steps_to_run: container of ``ClustFeaturesBuildSteps`` members; steps
            always execute in pipeline order 1 to 4, whatever the order in
            this container.
    """
    if not path_config:
        path_config = PathConfig(root=root, date=date, days_to_take=175, is_retro=True)

    # Fixed pipeline order: (step identifier, function implementing it).
    pipeline = (
        (ClustFeaturesBuildSteps.step_1_daily_hostnames_extract, step_1_daily_hostnames_extract),
        (ClustFeaturesBuildSteps.step_2_bow_reducer, step_2_bow_reducer),
        (ClustFeaturesBuildSteps.step_3_build_user2host, step_3_build_user2host),
        (ClustFeaturesBuildSteps.step_4_build_user2clust, step_4_build_user2clust),
    )
    for step_id, run_step in pipeline:
        if step_id in steps_to_run:
            run_step(yt_client=yt_client, date=date, path_config=path_config)

    path_config.collect_garbage(yt_client)
