#!/usr/bin/env python
# -*- coding: utf-8 -*-

import datetime
import json
import logging
import os

from dateutil.relativedelta import relativedelta
import numpy as np
import six
from statface_client import StatfaceClient

from yt.yson import get_bytes

from crypta.lib.proto.user_data import (
    user_data_pb2,
    user_data_stats_pb2,
)
from crypta.lib.python import templater
from crypta.lib.python.juggler.juggler_helpers import report_event_to_juggler
import crypta.lib.python.yql.client as yql_helpers
from crypta.lib.python.yt import yt_helpers
from crypta.lookalike.lib.python.utils import (
    fields,
    segment_features_calculator,
    user_features_calculator,
)
from crypta.lookalike.lib.python.utils.config import config
from crypta.lookalike.proto import yt_node_names_pb2
from crypta.siberia.bin.common.yt_describer.py import describe
from crypta.siberia.bin.common.yt_describer.proto.yt_describer_config_pb2 import TYtDescriberConfig


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# YQL query template (Jinja-style placeholders) that computes segment embeddings.
#
# The query attaches the file at {{model_path}} as 'dssm_lal_model.applier',
# loads it as a DSSM model and applies it with the 'doc_embedding' output to
# every row of {{segments_dssm_features_table}}. The result (GroupID, the
# additional fields and segment_vector) overwrites
# {{segments_dssm_vectors_table}} (INSERT ... WITH TRUNCATE).
#
# Placeholders:
#   model_path                    - path of the DSSM applier file.
#   segments_dssm_features_table  - input table.
#   segments_dssm_vectors_table   - output table.
#   additional_fields             - iterable of extra column names copied to the
#                                   output (NULL replaced with '').
#   additional_dssm_fields        - iterable of extra column names passed to the
#                                   model (NULL replaced with '').
get_segments_embeddings_query_template = """
PRAGMA File(
    'dssm_lal_model.applier',
    '{{model_path}}');

$model = Dssm::LoadModel(FilePath('dssm_lal_model.applier'), 'doc_embedding');

INSERT INTO `{{segments_dssm_vectors_table}}`
WITH TRUNCATE

SELECT
    GroupID,
    {% for field_name in additional_fields %}
    COALESCE({{field_name}}, '') AS {{field_name}},
    {% endfor %}
    Dssm::Apply(
        $model,
        AsStruct(
            COALESCE(segment_affinitive_sites_ids, '') AS segment_affinitive_sites_ids,
            COALESCE(segment_affinitive_apps, '') AS segment_affinitive_apps,
            COALESCE(segment_float_features, '') AS segment_float_features,
            {% for field_name in additional_dssm_fields %}
            COALESCE({{field_name}}, '') AS {{field_name}},
            {% endfor %}
        ),
        'doc_embedding'
    ) AS segment_vector,
FROM `{{segments_dssm_features_table}}`
"""

# YQL query template (Jinja-style placeholders) that computes user embeddings.
#
# The query attaches the file at {{model_path}} as 'dssm_lal_model.applier',
# loads it as a DSSM model and applies it with the 'query_embedding' output to
# every row of {{users_dssm_features_table}}. The result (yandexuid, the
# additional fields and user_vector) overwrites {{users_dssm_vectors_table}}
# (INSERT ... WITH TRUNCATE).
#
# Placeholders:
#   model_path                 - path of the DSSM applier file.
#   users_dssm_features_table  - input table.
#   users_dssm_vectors_table   - output table.
#   additional_fields          - substituted verbatim into the SELECT list.
#                                NOTE(review): unlike the segments template this
#                                is not iterated, and no comma follows it in the
#                                template, so a non-empty value presumably has
#                                to carry its own trailing comma - confirm
#                                against the caller.
#   additional_dssm_fields     - iterable of extra column names passed to the
#                                model (NULL replaced with '').
get_users_embeddings_query_template = """
PRAGMA File(
    'dssm_lal_model.applier',
    '{{model_path}}');

$model = Dssm::LoadModel(FilePath('dssm_lal_model.applier'), 'query_embedding');

INSERT INTO `{{users_dssm_vectors_table}}`
WITH TRUNCATE

SELECT
    yandexuid,
    {{additional_fields}}
    Dssm::Apply(
        $model,
        AsStruct(
            COALESCE(user_affinitive_sites_ids, '') AS user_affinitive_sites_ids,
            COALESCE(user_affinitive_apps, '') AS user_affinitive_apps,
            COALESCE(user_float_features, '') AS user_float_features,
            {% for field_name in additional_dssm_fields %}
            COALESCE({{field_name}}, '') AS {{field_name}},
            {% endfor %}
        ),
        'query_embedding'
    ) AS user_vector
FROM `{{users_dssm_features_table}}`
"""

# YQL query template that rebuilds the two categorical feature dictionaries.
# It is rendered in calculate_features_dicts() via templater.render_template.
#
# Part 1 - cities:
#   * $data takes rows of {{ user_data_table }} with a non-null, non-zero
#     Attributes.Region and resolves the English city name for the region.
#   * $most_common_cities keeps the {{ TOP_CITIES_NUMBER }} regions with the
#     most rows and appends a synthetic 'other' row with a zero counter.
#   * The result overwrites {{ main_region_city_dict }}; feature_index is
#     ROW_NUMBER() over yuids_cnt in descending order.
#
# Part 2 - trainable segments:
#   * $profiles_segments extracts the keys of lal_internal from
#     {{ segments_storage_by_yandexuid }} as UINT64 values.
#   * $trainable_segments_ids keeps the distinct ids that are listed in
#     {{ trainable_segments_ids }} (rendered inside an IN (...) clause).
#   * The result overwrites {{ trainable_segments_dict }}; feature_index is
#     ROW_NUMBER() over the segment id in ascending order.
get_features_dicts_query_template = """
$data = (
    SELECT
        yuid,
        Geo::RoundRegionById(CAST(Attributes.Region AS Int32), 'city').en_name AS city_name,
        CAST(Attributes.Region AS String) AS main_region_city,
    FROM `{{ user_data_table }}`
    WHERE Attributes.Region IS NOT Null AND Attributes.Region != 0
);

$most_common_cities = (
    (
        SELECT
            main_region_city AS feature,
            SOME(city_name) AS city_name,
            COUNT(*) AS yuids_cnt,
        FROM $data
        GROUP BY main_region_city
        ORDER BY yuids_cnt DESC
        LIMIT {{ TOP_CITIES_NUMBER }}
    )

UNION ALL

    SELECT
        'other' AS feature,
        'other' AS city_name,
        CAST(0 AS Uint64) AS yuids_cnt,
);

INSERT INTO `{{ main_region_city_dict }}`
WITH TRUNCATE

SELECT
    feature,
    city_name,
    ROW_NUMBER() OVER w AS feature_index,
FROM $most_common_cities
WINDOW w AS (ORDER BY yuids_cnt DESC)
ORDER BY feature_index;

$profiles_segments = (
    SELECT
        yandexuid,
        ListMap(DictKeys(Yson::ConvertToUint64Dict(lal_internal)), ($x) -> (CAST($x AS UINT64))) AS lal_internal_segments_list,
    FROM `{{ segments_storage_by_yandexuid }}`
);

$trainable_segments_ids = (
    SELECT DISTINCT trainable_segments
    FROM $profiles_segments
    FLATTEN BY lal_internal_segments_list AS trainable_segments
    WHERE trainable_segments IN ({{ trainable_segments_ids }})
);

INSERT INTO `{{ trainable_segments_dict }}`
WITH TRUNCATE

SELECT
    trainable_segments AS feature,
    ROW_NUMBER() OVER w AS feature_index,
FROM $trainable_segments_ids
WINDOW w AS (ORDER BY trainable_segments)
ORDER BY feature_index;
"""


def describe_segments_in_siberia(yt_client, input_table, output_table, userdata_for_description, transaction):
    """Describe the segments of ``input_table`` with the Siberia YT describer.

    The output table is first created empty (``force=True``) with a string
    group-id column and a string stats column, then ``describe`` is run with
    a config pointing at the input/output tables, the user data table and the
    common temporary directory.
    """
    describer_config = TYtDescriberConfig(
        CryptaIdUserDataTable=userdata_for_description,
        TmpDir=config.COMMON_TMP_DIRECTORY,
        InputTable=input_table,
        OutputTable=output_table,
    )

    output_schema = {
        fields.group_id: 'string',
        fields.stats: 'string',
    }
    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=output_table,
        schema=output_schema,
        additional_attributes={'optimize_for': 'scan'},
        force=True,
    )

    describe(yt_client, transaction, describer_config)


def get_segment_dssm_features(yt_client, segments_with_description_table, segment_dssm_features_table, features_mapping):
    """Convert described segments into DSSM input features.

    Creates the empty output table, maps ``segments_with_description_table``
    into it with ``MakeDssmSegmentFeaturesMapper`` (the mapping keys are
    converted with ``get_bytes``) and finally sorts the output by group id.
    """
    output_schema = {
        fields.group_id: 'string',
        fields.segment_float_features: 'string',
        fields.segment_affinitive_sites_ids: 'string',
        fields.segment_affinitive_apps: 'string',
    }
    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=segment_dssm_features_table,
        schema=output_schema,
        additional_attributes={'optimize_for': 'scan'},
        force=True,
    )

    bytes_keyed_mapping = {get_bytes(name): index for name, index in features_mapping.items()}
    mapper = MakeDssmSegmentFeaturesMapper(features_mapping=bytes_keyed_mapping)
    yt_client.run_map(mapper, segments_with_description_table, segment_dssm_features_table)

    yt_client.run_sort(segment_dssm_features_table, sort_by=fields.group_id)


def get_user_dssm_features(yt_client, user_data_table, output_table, features_mapping, sampling_rate=1.0):
    """Convert UserData rows into DSSM input features.

    Creates the empty output table and runs ``MakeDssmUserFeaturesMapper``
    over the needed columns of ``user_data_table``. ``sampling_rate`` is
    passed to the table reader settings together with a fixed sampling seed.
    """
    output_schema = {
        fields.yandexuid: 'uint64',
        fields.cryptaId: 'uint64',
        fields.user_float_features: 'string',
        fields.user_affinitive_sites_ids: 'string',
        fields.user_affinitive_apps: 'string',
    }
    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=output_table,
        schema=output_schema,
        additional_attributes={'optimize_for': 'scan'},
        force=True,
    )

    mapper = MakeDssmUserFeaturesMapper(
        yandexuid_field_name='yuid',
        features_mapping={get_bytes(name): index for name, index in features_mapping.items()},
    )
    source_table = yt_client.TablePath(
        user_data_table,
        columns=['yuid', 'CryptaID', 'Vectors', 'Segments', 'Attributes', 'Affinities'],
    )
    operation_spec = {
        'title': 'LaL convert UserData to dssm features',
        'job_io': {'table_reader': {'sampling_seed': 1, 'sampling_rate': sampling_rate}},
    }
    yt_client.run_map(mapper, source_table, output_table, spec=operation_spec)


def get_yt_client(nv_params=None, yt_pool='crypta_lookalike'):
    """Build a YT client.

    With ``nv_params`` the cluster, pool and token are taken from it and the
    client is additionally configured with the lookalike prefix, an ACL for
    'crypta-team', parallel reading and the common temp directory. Without
    ``nv_params`` the proxy comes from the config, the pool from ``yt_pool``
    and the token from the ``YT_TOKEN`` environment variable.
    """
    if nv_params is not None:
        team_acl = [
            {
                'subjects': ['crypta-team'],
                'action': 'allow',
                'permissions': ['manage', 'read'],
            },
        ]
        parallel_read_settings = {
            'max_thread_count': 20,
            'data_size_per_thread': 8 * 1024 * 1024,
            'enable': True,
        }
        return yt_helpers.get_yt_client(
            yt_proxy=str(nv_params['mr-default-cluster']),
            yt_pool=str(nv_params['yt-pool']),
            yt_token=str(nv_params['yt-token']),
            yt_prefix=config.LOOKALIKE_DIRECTORY + '/',
            acl=team_acl,
            read_parallel=parallel_read_settings,
            remote_temp_tables_directory=config.COMMON_TMP_DIRECTORY,
        )

    return yt_helpers.get_yt_client(
        yt_proxy=config.YT_PROXY,
        yt_pool=yt_pool,
        yt_token=os.environ['YT_TOKEN'],
    )


def get_yql_client(nv_params=None, yt_pool='crypta_lookalike'):
    """Build a YQL client.

    With ``nv_params`` the cluster, token and pool are taken from it, syntax
    version 1 and the common YQL temp folder are set. Without ``nv_params``
    the proxy comes from the config, the pool from ``yt_pool`` and the token
    from the ``YQL_TOKEN`` environment variable.
    """
    if nv_params is not None:
        return yql_helpers.create_yql_client(
            yt_proxy=nv_params['mr-default-cluster'],
            token=nv_params['yql-token'],
            pool=nv_params['yt-pool'],
            syntax_version=1,
            tmp_folder=config.COMMON_YQL_TMP_DIRECTORY,
        )

    return yql_helpers.create_yql_client(
        yt_proxy=config.YT_PROXY,
        pool=yt_pool,
        token=os.environ['YQL_TOKEN'],
    )


def get_statface_client(nv_params):
    """Build a Statface client for 'robot-unicorn' using the 'stat-token' entry of ``nv_params``."""
    # The token is stripped of surrounding whitespace before use.
    oauth_token = nv_params['stat-token'].strip()
    return StatfaceClient(
        username='robot-unicorn',
        oauth_token=oauth_token,
        host='upload.stat.yandex-team.ru',
    )


def normalize(x):
    """Scale vector ``x`` to unit Euclidean length.

    A vector whose dot product with itself is zero is returned unchanged
    (the very same object) to avoid dividing by zero.
    """
    squared_length = np.dot(x, x)
    return x if squared_length == 0 else x / np.sqrt(squared_length)


def get_feature_name(feature_type, value):
    """Return the feature key built as ``<feature_type>_<value>``."""
    return '{0}_{1}'.format(feature_type, value)


def get_dssm_entities_from_dir(model_dir):
    """Return the paths of the DSSM entities located in ``model_dir``.

    Returns:
        A tuple ``(user_embeddings_table, [dssm_model_file, segments_dict_file])``
        where node names come from ``TYtNodeNames``.
    """
    node_names = yt_node_names_pb2.TYtNodeNames()

    def in_model_dir(node_name):
        return os.path.join(model_dir, node_name)

    embeddings_table = in_model_dir(node_names.UserEmbeddingsTable)
    model_file = in_model_dir(node_names.DssmModelFile)
    segments_dict_file = in_model_dir(node_names.SegmentsDictFile)
    return embeddings_table, [model_file, segments_dict_file]


def get_last_version_of_dssm_entities(yt_client):
    """Find the newest version directory that contains all DSSM entities.

    Versions are scanned in reverse sorted order; the first directory holding
    the model file, the segments dict file and the user embeddings table wins.

    Returns:
        The result of ``get_dssm_entities_from_dir`` for that directory, or
        ``(None, [])`` when no version is complete.
    """
    node_names = yt_node_names_pb2.TYtNodeNames()
    versions_root = config.LOOKALIKE_VERSIONS_DIRECTORY
    required_nodes = (
        node_names.DssmModelFile,
        node_names.SegmentsDictFile,
        node_names.UserEmbeddingsTable,
    )

    for version in sorted(yt_client.list(versions_root), reverse=True):
        version_dir = os.path.join(versions_root, version)
        version_nodes = yt_client.list(version_dir)
        if all(node in version_nodes for node in required_nodes):
            return get_dssm_entities_from_dir(version_dir)

    return None, []


def copy_last_lal_model(yt_client, output_dir):
    """Copy the files of the latest complete DSSM version into ``output_dir``.

    The first file is copied as 'dssm_model.applier' and the second one as
    'segments_dict.json'. Existing targets are overwritten (``force=True``).
    """
    _, source_files = get_last_version_of_dssm_entities(yt_client)

    target_names = ('dssm_model.applier', 'segments_dict.json')
    for position, target_name in enumerate(target_names):
        # Index access on purpose: a missing source file raises IndexError.
        yt_client.copy(
            source_files[position],
            os.path.join(output_dir, target_name),
            recursive=True,
            force=True,
        )


def get_old_dssm_model(yt_client, last_date):
    """Return ``[model_file, segments_dict_file]`` of the oldest suitable monthly version.

    Monthly versions are scanned in sorted order; versions that compare less
    than ``last_date`` or lack one of the two files are skipped. If nothing
    fits, the files of the latest complete regular version are returned.
    """
    node_names = yt_node_names_pb2.TYtNodeNames()
    model_name = node_names.DssmModelFile
    segments_dict_name = node_names.SegmentsDictFile

    for version in sorted(yt_client.list(config.LOOKALIKE_MONTHLY_VERSIONS_DIRECTORY)):
        if version < last_date:
            continue
        version_dir = os.path.join(config.LOOKALIKE_MONTHLY_VERSIONS_DIRECTORY, version)
        version_nodes = yt_client.list(version_dir)
        if model_name in version_nodes and segments_dict_name in version_nodes:
            return [
                os.path.join(version_dir, model_name),
                os.path.join(version_dir, segments_dict_name),
            ]

    # No monthly version matched: fall back to the latest regular version.
    _, fallback_files = get_last_version_of_dssm_entities(yt_client)
    return fallback_files


def get_date_from_past(current_date, days=0, months=0, years=0):
    """Return the date lying the given offset before ``current_date``.

    Args:
        current_date: a date string in ``config.DATE_FORMAT``, a
            ``datetime.datetime`` or a ``datetime.date``.
        days: number of days to subtract.
        months: number of months to subtract.
        years: number of years to subtract.

    Returns:
        ``str`` of the shifted ``datetime.date`` for string input, a
        ``datetime.date`` for date and datetime input, and ``None`` for any
        other input type.
    """
    delta = relativedelta(days=days, months=months, years=years)

    if isinstance(current_date, six.string_types):
        parsed_date = datetime.datetime.strptime(current_date, config.DATE_FORMAT).date()
        return str(parsed_date - delta)

    # datetime.datetime is a subclass of datetime.date, so it has to be
    # tested first; otherwise the date branch swallows datetimes and the
    # time part is never dropped.
    if isinstance(current_date, datetime.datetime):
        return current_date.date() - delta

    if isinstance(current_date, datetime.date):
        return current_date - delta

    return None


class MakeDssmSegmentFeaturesMapper(object):
    """Mapper that turns a row with serialized segment stats into DSSM segment features."""

    def __init__(self, features_mapping):
        """Remember the feature-name -> index mapping handed to the features calculator."""
        self.features_mapping = features_mapping

    def start(self):
        """Create the reusable stats message and the features calculator.

        NOTE(review): presumably invoked by YT once per job before any row is
        processed - confirm.
        """
        self.user_data_stats = user_data_stats_pb2.TUserDataStats()
        self.features_calculator = segment_features_calculator.TSegmentFeaturesCalculator(self.features_mapping)

    def __call__(self, row):
        """Parse the 'Stats' column of ``row`` and yield one row of segment features."""
        stats = self.user_data_stats
        stats.ParseFromString(get_bytes(row['Stats']))
        calculator = self.features_calculator

        output_row = {fields.group_id: row[fields.group_id]}
        output_row[fields.segment_float_features] = calculator.PrepareFloatFeatures(stats)
        output_row[fields.segment_affinitive_sites_ids] = calculator.PrepareAffinitiveSitesIds(stats)
        output_row[fields.segment_affinitive_apps] = calculator.PrepareAffinitiveApps(stats)
        yield output_row


class MakeDssmUserFeaturesMapper(object):
    """Mapper that turns a UserData row into DSSM user features."""

    # Columns that may be NULL; each one is parsed into the sub-message of
    # TUserData with the same name.
    _OPTIONAL_COLUMNS = ('Attributes', 'Affinities', 'Segments')

    def __init__(self, yandexuid_field_name, features_mapping):
        """Remember the features mapping and the name of the input column holding the yandexuid."""
        self.features_mapping = features_mapping
        self.yuid_field_name = yandexuid_field_name

    def start(self):
        """Create the reusable TUserData message and the features calculator.

        NOTE(review): presumably invoked by YT once per job before any row is
        processed - confirm.
        """
        self.user_data = user_data_pb2.TUserData()
        self.features_calculator = user_features_calculator.TUserFeaturesCalculator(self.features_mapping)

    def __call__(self, row):
        """Parse the serialized columns of ``row`` and yield one row of user features.

        The yandexuid and CryptaID are converted to ``int`` (``None`` is kept
        as ``None``).
        """
        # The message is shared between rows. ParseFromString() only resets the
        # sub-message it is called on, so a row with a NULL optional column
        # would otherwise inherit that sub-message from the previous row.
        self.user_data.Clear()

        self.user_data.Vectors.ParseFromString(get_bytes(row['Vectors']))
        for column in self._OPTIONAL_COLUMNS:
            serialized = row[column]
            if serialized is not None:
                getattr(self.user_data, column).ParseFromString(get_bytes(serialized))

        yandexuid = row[self.yuid_field_name]
        crypta_id = row['CryptaID']

        yield {
            fields.yandexuid: int(yandexuid) if yandexuid is not None else None,
            fields.cryptaId: int(crypta_id) if crypta_id is not None else None,
            fields.user_float_features: self.features_calculator.PrepareFloatFeatures(self.user_data),
            fields.user_affinitive_sites_ids: self.features_calculator.PrepareAffinitiveSitesIds(self.user_data),
            fields.user_affinitive_apps: self.features_calculator.PrepareAffinitiveApps(self.user_data),
        }


def calculate_features_dicts(yt_client, yql_client, date, transaction):
    """Rebuild the 'main_region_city' and 'trainable_segments' feature dictionaries.

    Ids of exports tagged 'trainable' are collected from the lab segments
    table, the features-dicts YQL query is rendered and executed inside
    ``transaction``, and both dictionary tables get the 'generate_date'
    attribute set to ``date``.
    """
    trainable_segments_ids = [
        str(export['segment_id'])
        for row in yt_client.read_table(config.LAB_SEGMENTS_TABLE)
        for export in row['exports']['exports']
        if 'trainable' in export['tags']
    ]

    dict_paths = [
        os.path.join(config.CATEGORICAL_FEATURES_MATCHING_DIR, dict_type)
        for dict_type in ('main_region_city', 'trainable_segments')
    ]
    main_region_city_path, trainable_segments_path = dict_paths

    template_vars = {
        'main_region_city_dict': main_region_city_path,
        'user_data_table': config.USER_DATA_TABLE,
        'trainable_segments_dict': trainable_segments_path,
        'segments_storage_by_yandexuid': config.SEGMENTS_STORAGE_BY_YANDEXUID,
        'trainable_segments_ids': ', '.join(trainable_segments_ids),
        'TOP_CITIES_NUMBER': config.TOP_CITIES_NUMBER,
    }
    query = templater.render_template(get_features_dicts_query_template, vars=template_vars)

    yql_client.execute(
        query=query,
        transaction=str(transaction.transaction_id),
        title='YQL calculate categorical features matching',
    )

    for dict_path in dict_paths:
        yt_client.set_attribute(dict_path, 'generate_date', date)


def get_features_mapping_dict(yt_client, yql_client, transaction, is_experiment=False):
    """Build the mapping ``feature name -> index`` over all categorical features.

    Unless ``is_experiment`` is set, the feature dictionaries are recalculated
    first with today's date. Categorical dictionaries are laid out one after
    another (each table's ``feature_index`` minus one, shifted by the running
    offset), followed by the gender, age and income classes.
    """
    if not is_experiment:
        calculate_features_dicts(yt_client, yql_client, str(datetime.date.today()), transaction)

    features_mapping = {}
    offset = 0

    for segment_type in config.categorical_feature_types:
        dict_table = os.path.join(config.CATEGORICAL_FEATURES_MATCHING_DIR, segment_type)
        for row in yt_client.read_table(dict_table):
            feature_name = get_feature_name(
                feature_type=config.categorical_feature_name_to_keyword[segment_type],
                value=row[fields.feature],
            )
            features_mapping[feature_name] = row[fields.feature_index] + offset - 1
        # The next dictionary starts right after this one.
        offset += yt_client.row_count(dict_table)

    socdem_classes = (
        (fields.gender, config.GENDER_CLASSES_NUM),
        (fields.age, config.AGE_CLASSES_NUM),
        (fields.income, config.INCOME_CLASSES_NUM),
    )
    for socdem, classes_num in socdem_classes:
        features_mapping.update(
            (get_feature_name(feature_type=socdem, value=value), offset + value)
            for value in range(classes_num)
        )
        offset += classes_num

    return features_mapping


def write_features_mapping_to_yt(yt_client, features_mapping, features_mapping_table_path):
    """Store ``features_mapping`` as a YT table sorted by feature index.

    The table is created empty (``force=True``) with a string feature column
    and a uint64 index column, filled with one row per mapping entry and then
    sorted.
    """
    table_schema = {
        fields.feature: 'string',
        fields.feature_index: 'uint64',
    }
    yt_helpers.create_empty_table(
        yt_client=yt_client,
        path=features_mapping_table_path,
        schema=table_schema,
        force=True,
    )

    rows = [
        {fields.feature: name, fields.feature_index: index}
        for name, index in features_mapping.items()
    ]
    yt_client.write_table(features_mapping_table_path, rows)

    yt_client.run_sort(features_mapping_table_path, sort_by=fields.feature_index)
    logger.info('Features mapping has been written to hahn')


def compare_feature_mappings(yt_client, yql_client, features_mapping_table, transaction, is_experiment=False):
    """Compare the stored features mapping with a freshly built one.

    A full retrain is requested when the number of deleted or added non-city
    features, or the number of changed city features, reaches the thresholds
    from the config. In that case a WARN event is sent to Juggler and the new
    mapping is written to ``features_mapping_table``; otherwise an OK event is
    sent.

    Returns:
        ``(True, new_mapping)`` when a full retrain is needed,
        ``(False, old_mapping)`` otherwise.
    """
    # NOTE(review): only the first item yielded by read_file() is parsed -
    # confirm the segments dict always arrives in a single piece.
    old_mapping = json.loads(next(yt_client.read_file(config.SEGMENTS_DICT_FILE)))
    new_mapping = get_features_mapping_dict(yt_client, yql_client, transaction, is_experiment)

    old_features = set(old_mapping.keys())
    new_features = set(new_mapping.keys())
    deleted_segments = old_features - new_features
    new_segments = new_features - old_features

    deleted_cities = [feature for feature in deleted_segments if feature.startswith('city')]
    new_cities = [feature for feature in new_segments if feature.startswith('city')]

    retrain_required = (
        len(deleted_segments) - len(deleted_cities) >= config.DELETED_SEGMENTS_NUMBER_TO_RETRAIN or
        len(new_segments) - len(new_cities) >= config.NEW_SEGMENTS_NUMBER_TO_RETRAIN or
        max(len(deleted_cities), len(new_cities)) >= config.TOP_CITIES_UPDATE_NUMBER_TO_RETRAIN
    )

    if not retrain_required:
        report_event_to_juggler(
            status='OK',
            service='lal_model_dict_update',
            host=config.CRYPTA_ML_JUGGLER_HOST,
            description='Features dict have not been updated',
            logger=logger,
        )
        logger.info('Old model will be used for fine-tuning.')
        return False, old_mapping

    logger.info('Full model retrain will be performed due to update of features dict')
    logger.info('New features: ' + ' '.join(new_segments) + '\nDeleted features: ' + ' '.join(deleted_segments))

    report_event_to_juggler(
        status='WARN',
        service='lal_model_dict_update',
        host=config.CRYPTA_ML_JUGGLER_HOST,
        description='Critical update of features dict: need full retrain',
        logger=logger,
    )
    write_features_mapping_to_yt(yt_client, new_mapping, features_mapping_table)
    return True, new_mapping


def get_production_path(path, nv_params):
    """Map a path inside the experiment working dir back to the production lookalike dir.

    ``path`` is returned untouched when ``nv_params`` is ``None`` or carries
    no 'working-dir' entry.
    """
    if nv_params is None or 'working-dir' not in nv_params.keys():
        return path

    working_dir = nv_params['working-dir']
    return path.replace(working_dir, config.LOOKALIKE_DIRECTORY)


def get_lal_model_source_link(released=config.RELEASED, inputs=None, file_name=None):
    """Return the Sandbox proxy URL of the LaL model.

    When ``inputs`` has a 'resource_info' entry, the JSON file it points to
    provides the resource id and, unless ``file_name`` is given, the file
    name. Otherwise the link to the last resource with the given ``released``
    attribute is returned.
    """
    if inputs is None or not inputs.has('resource_info'):
        return "https://proxy.sandbox.yandex-team.ru/last/CRYPTA_LOOK_ALIKE_MODEL/dssm_lal_model.applier?" \
               "attrs={{\"released\":\"{}\"}}".format(released)

    with open(inputs.get('resource_info'), 'r') as resource_info_file:
        resource_info = json.load(resource_info_file)

    if not file_name:
        file_name = resource_info['file_name']
    resource_url = 'https://proxy.sandbox.yandex-team.ru/{}'.format(str(resource_info['resource_id']))
    return os.path.join(resource_url, file_name)


def replaced_for_experiment(path):
    """Tell whether a config value is a path that must be redirected for an experiment.

    True only for ``str`` values that start with the LaL training or test
    directory, or equal the segments-with-counts table.
    """
    if not isinstance(path, str):
        return False

    experiment_prefixes = (config.LAL_TRAINING_DIRECTORY, config.LAL_TEST_DIRECTORY)
    return path.startswith(experiment_prefixes) or path == config.SEGMENTS_WITH_COUNTS_TABLE


def replace_working_dir(production_path, nv_params):
    """Return ``production_path`` with the production lookalike dir replaced by the 'working-dir' of ``nv_params``."""
    working_dir = nv_params['working-dir']
    return production_path.replace(config.LOOKALIKE_DIRECTORY, working_dir)


def update_config_for_experiment(nv_params):
    """Redirect every experiment-related path stored in ``config`` to the working dir of ``nv_params``."""
    # Overrides are collected first so the config is not modified while it is iterated.
    overrides = {}
    for key, path in config.__dict__.items():
        if replaced_for_experiment(path):
            overrides[key] = replace_working_dir(path, nv_params)

    config.__dict__.update(overrides)


def get_additional_dssm_fields(yt_client, base_path, extended_table_path):
    """List columns of the extended table that are absent from the base table.

    The column 'features' is always treated as a base column. The order of the
    extended table's schema is preserved.
    """
    base_fields = {column['name'] for column in yt_client.get_attribute(base_path, 'schema')}
    base_fields.add('features')

    additional_fields = []
    for column in yt_client.get_attribute(extended_table_path, 'schema'):
        if column['name'] not in base_fields:
            additional_fields.append(column['name'])
    return additional_fields
