#!/usr/bin/env python
# -*- coding: utf-8 -*-

import functools
import json
import os

import numpy as np

from crypta.lib.python import templater
from crypta.lookalike.lib.python.utils import (
    utils as lal_utils,
    fields as lal_fields,
)
from crypta.profile.utils.config import config
from crypta.profile.utils.luigi_utils import (
    ExternalInput,
    ExternalInputDate,
)
from crypta.profile.utils.segment_utils.builders import RegularSegmentBuilder
from crypta.siberia.bin.common.yt_describer.py import describe
from crypta.siberia.bin.common.yt_describer.proto.yt_describer_config_pb2 import TYtDescriberConfig


# Value passed as `released=` to lal_utils.get_lal_model_source_link when
# choosing the DSSM model used to embed segments.
RELEASE_TYPE = 'stable'
# Seed segments for look-alike expansion: segment id -> LaL segment name.
# The ids are matched against the `heuristic_common` list of the profiles
# table (see get_edadeal_user_data_query_template); the names are the values
# written to `segment_name` in the final output and are the keys of
# OfflinePurchasesLookAlike.name_segment_dict.
# NOTE(review): 'diary_lal' and 'washing_mashine_lal' look like misspellings
# of "dairy" / "machine", but they are runtime identifiers shared with
# name_segment_dict -- rename only as a coordinated change.
SEGMENTS_FOR_LAL = {
    1341: 'soda_lal',
    1167: 'baby_toys_lal',
    1164: 'goods_for_children_lal',
    1166: 'food_for_kids_lal',
    1170: 'coffee_lal',
    1171: 'cosmetics_mass_market_lal',
    1344: 'low-alcohol_lal',
    1346: 'juice_lal',
    1168: 'supplies_for_cats_lal',
    1169: 'supplies_for_dogs_lal',
    1172: 'household_chemicals_lal',
    1340: 'water_lal',
    1342: 'yogurt_lal',
    1343: 'diary_lal',
    1351: 'premium_quality_juice_lal',
    1173: 'sweets_lal',
    1345: 'snacks_lal',
    2099: 'shaving_products_lal',
    2098: 'hair_products_lal',
    2090: 'washing_mashine_lal',
    2100: 'laundry_products_lal',
    1347: 'cheese_lal',
    1348: 'ice_tea_lal',
    1248: 'tea_lal',
    1349: 'chips_lal',
    1350: 'energy_drink_lal',
    1165: 'diapers_lal',
}

# YQL template (filled with str.format) that collects the seed users for the
# look-alike model.
#
# Placeholders:
#   {profiles_for_14days_table} - profiles table with `yandexuid` and a
#                                 `heuristic_common` list of segment ids;
#   {user_data_table}           - user data table, read through its `raw` view;
#   {needed_segments}           - comma-separated segment ids to keep;
#   {output_table}              - destination table (overwritten: WITH TRUNCATE).
#
# The query flattens `heuristic_common` into one row per (yuid, segment_id),
# inner-joins it with the user data by `yuid`, keeps only the needed segments
# and writes rows of (IdValue=yuid, IdType='yandexuid', GroupID=segment id as
# a string).
get_edadeal_user_data_query_template = """
$profiles_heuristic = (
    SELECT CAST(yandexuid AS String) AS yuid,
        Yson::ConvertToUint64List(heuristic_common) AS heuristic_common
    FROM `{profiles_for_14days_table}`
);

$profiles_heuristic_flattened = (
    SELECT yuid,
        segment_id
    FROM $profiles_heuristic
    FLATTEN BY heuristic_common AS segment_id
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT
    user_data.yuid AS IdValue,
    'yandexuid' AS IdType,
    Cast(profiles_heuristic.segment_id AS String) AS GroupID
FROM `{user_data_table}` VIEW raw AS user_data
INNER JOIN $profiles_heuristic_flattened AS profiles_heuristic
USING(yuid)
WHERE profiles_heuristic.segment_id in ({needed_segments});
"""

# YQL template (filled with str.format) that picks the final look-alike users
# for every segment.
#
# Placeholders:
#   {user_data_edadeal_table}    - seed users table with a `GroupID` column
#                                  (segment id as a string);
#   {dssm_lal_distances_table}   - table of (yandexuid, segment_id, distance);
#   {top_segments_for_user_cnt}  - how many closest segments to keep per user;
#   {segment_id_to_name}         - comma-separated `AsTuple(id, 'name')` items
#                                  used to build the id -> name dictionary;
#   {max_volume}                 - output size allowed for the largest segment;
#   {output_table}               - destination table (overwritten: WITH TRUNCATE).
#
# Steps:
#   1. count seed users per segment and take the maximum of those counts;
#   2. rank segments for each user by ascending distance and keep the first
#      {top_segments_for_user_cnt} of them;
#   3. rank the remaining users inside each segment by ascending distance;
#   4. keep users whose rank is at most
#      {max_volume} * <segment seed count> / <max seed count>, i.e. the output
#      volume of a segment is proportional to its seed size;
#   5. write (id, id_type='yandexuid', segment_name).
#
# NOTE(review): the intermediate name `$lal_by_top3_segments` suggests a top-3
# cut, but the actual cut is whatever {top_segments_for_user_cnt} is set to.
select_lal_users_query_template = """
$edadeal_heuristic_cnt = (
    SELECT CAST(GroupID AS Uint64) AS segment_id,
        COUNT(*) AS yuids_cnt
    FROM `{user_data_edadeal_table}`
    GROUP BY GroupID
);

$max_edadeal_heuristic_cnt =(
    SELECT MAX(yuids_cnt) FROM $edadeal_heuristic_cnt
);

$lal_by_user = (
    SELECT yandexuid,
        segment_id, distance,
        ROW_NUMBER() OVER w AS segment_rank
    FROM `{dssm_lal_distances_table}`
    WINDOW w AS (
        PARTITION BY yandexuid
        ORDER BY distance
    )
);

$lal_by_top3_segments = (
    SELECT yandexuid,
        segment_id,
        distance,
        ROW_NUMBER() OVER w AS user_rank
    FROM $lal_by_user
    WHERE segment_rank <= {top_segments_for_user_cnt}
    WINDOW w AS (
        PARTITION BY segment_id
        ORDER BY distance
    )
);

$heuristic_ids_to_lal_segment_name = AsDict(
{segment_id_to_name}
);

INSERT INTO `{output_table}`
WITH TRUNCATE

SELECT CAST(lal.yandexuid AS String) AS id,
    'yandexuid' AS id_type,
    $heuristic_ids_to_lal_segment_name[lal.segment_id] AS segment_name
FROM $lal_by_top3_segments AS lal
INNER JOIN $edadeal_heuristic_cnt AS heuristic_counts
USING(segment_id)
WHERE lal.user_rank <= {max_volume} * heuristic_counts.yuids_cnt / $max_edadeal_heuristic_cnt;
"""


def calculate_dssm_lal_scores(row, segments):
    """Mapper: emit the distance from one user to every segment.

    Args:
        row: input record; its ``'embedding'`` and ``'user_id'`` keys are read.
        segments: iterable of dicts holding ``lal_fields.group_id`` and
            ``lal_fields.vector`` for each segment.

    Yields:
        One dict per segment with the user id (``lal_fields.yandexuid``),
        the integer ``'segment_id'`` and ``lal_fields.distance`` computed as
        ``1 - dot(segment_vector, normalized_user_vector)``.
        Presumably this is the cosine distance -- assumes
        ``lal_utils.normalize`` returns unit-length vectors and that the
        segment vectors were normalized by the caller; TODO confirm.
    """
    normalized_embedding = lal_utils.normalize(row['embedding'])

    for segment_info in segments:
        uid = row['user_id']
        segment_id = int(segment_info[lal_fields.group_id])
        similarity = np.dot(segment_info[lal_fields.vector], normalized_embedding)
        yield {
            lal_fields.yandexuid: uid,
            'segment_id': segment_id,
            lal_fields.distance: 1.0 - similarity,
        }


class OfflinePurchasesLookAlike(RegularSegmentBuilder):
    """Builds look-alike segments from the edadeal heuristic seed segments.

    For every seed segment listed in ``SEGMENTS_FOR_LAL`` the task embeds the
    segment with the DSSM look-alike model, measures the distance from every
    user embedding to every segment embedding and writes the closest users
    to the output table as ``(id, id_type, segment_name)`` rows.
    """

    keyword = 547

    # LaL segment name -> segment id. The keys are the names produced in the
    # `segment_name` column of the output (the values of SEGMENTS_FOR_LAL).
    # Presumably consumed by RegularSegmentBuilder to translate names into
    # ids under `keyword` -- the base class is not visible here, verify.
    name_segment_dict = {
        'soda_lal': 1632,
        'baby_toys_lal': 1633,
        'goods_for_children_lal': 1634,
        'food_for_kids_lal': 1635,
        'coffee_lal': 1636,
        'cosmetics_mass_market_lal': 1637,
        'low-alcohol_lal': 1638,
        'juice_lal': 1639,
        'supplies_for_cats_lal': 1640,
        'supplies_for_dogs_lal': 1921,
        'household_chemicals_lal': 2201,
        'water_lal': 2202,
        'yogurt_lal': 2204,
        'diary_lal': 2205,
        'premium_quality_juice_lal': 2206,
        'sweets_lal': 2207,
        'snacks_lal': 2208,
        'shaving_products_lal': 2209,
        'hair_products_lal': 2210,
        'washing_mashine_lal': 2211,
        'laundry_products_lal': 2212,
        'cheese_lal': 2213,
        'ice_tea_lal': 2214,
        'tea_lal': 2215,
        'chips_lal': 2216,
        'energy_drink_lal': 2217,
        'diapers_lal': 2203,
    }

    def requires(self):
        """Return the external inputs of the task, keyed by name.

        Returns:
            dict with ``'profiles_for_14days'``, ``'user_data'``,
            ``'user_embeddings'`` and ``'features_mapping'`` inputs.
        """
        # Resolve the version directory once so that the embeddings and the
        # features mapping always come from the same model version (listing
        # the directory twice could pick different versions if a new one
        # appears between the calls). The maximum path is taken -- presumably
        # the most recent version.
        lookalike_version_dir = max(
            self.yt.list(config.LOOKALIKE_VERSIONS_DIRECTORY, absolute=True)
        )

        return {
            'profiles_for_14days': ExternalInputDate(
                table=config.YANDEXUID_EXPORT_PROFILES_14_DAYS_TABLE,
                date=self.date,
            ),
            'user_data': ExternalInputDate(
                table=config.USER_DATA_TABLE,
                date=self.date,
                field='_last_update_date',
            ),
            'user_embeddings': ExternalInputDate(
                table=os.path.join(lookalike_version_dir, 'user_embeddings'),
                date=self.date,
                field='_last_update_date',
            ),
            'features_mapping': ExternalInput(
                table=os.path.join(lookalike_version_dir, 'segments_dict.json'),
            ),
        }

    def get_segments_vectors(self, table):
        """Read segment embeddings from ``table`` into memory.

        Args:
            table: path of a table whose rows carry ``lal_fields.group_id``
                and ``'segment_vector'``.

        Returns:
            list of dicts with ``lal_fields.group_id`` and the normalized
            vector under ``lal_fields.vector``.
        """
        return [
            {
                lal_fields.group_id: row[lal_fields.group_id],
                lal_fields.vector: lal_utils.normalize(row['segment_vector']),
            }
            for row in self.yt.read_table(table)
        ]

    def build_segment(self, inputs, output_path):
        """Compute the look-alike users and write them to ``output_path``.

        Args:
            inputs: mapping produced from ``requires()``; each value exposes
                a ``table`` attribute.
            output_path: destination table for the resulting
                ``(id, id_type, segment_name)`` rows.
        """
        with self.yt.Transaction() as transaction, \
                self.yt.TempTable() as user_data_edadeal_table, \
                self.yt.TempTable() as user_data_stats_edadeal_table, \
                self.yt.TempTable() as edadeal_segments_dssm_features_table, \
                self.yt.TempTable() as segments_dssm_vectors_table, \
                self.yt.TempTable() as dssm_lal_distances_table:
            # 1. Collect the seed users of every segment of interest.
            self.yql.query(
                query_string=get_edadeal_user_data_query_template.format(
                    needed_segments=', '.join(map(str, SEGMENTS_FOR_LAL)),
                    profiles_for_14days_table=inputs['profiles_for_14days'].table,
                    user_data_table=inputs['user_data'].table,
                    output_table=user_data_edadeal_table,
                ),
                transaction=transaction,
                title='Get edadeal users for LaL',
            )

            # 2. Describe each seed segment: the describer reads the seed
            #    users table and writes into the stats table created here.
            self.yt.create_empty_table(
                user_data_stats_edadeal_table,
                schema={
                    lal_fields.group_id: 'string',
                    'Stats': 'string',
                },
            )

            segments_description_config = TYtDescriberConfig(
                CryptaIdUserDataTable=config.FOR_DESCRIPTION_BY_CRYPTAID_TABLE,
                TmpDir=config.PROFILES_TMP_YT_DIRECTORY,
                InputTable=user_data_edadeal_table,
                OutputTable=user_data_stats_edadeal_table,
            )

            describe(self.yt, transaction, segments_description_config)

            # 3. Turn the segment stats into DSSM input features.
            # NOTE(review): only the first chunk yielded by read_file is
            # parsed -- assumes the whole JSON fits into it; confirm.
            features_mapping = json.loads(next(self.yt.read_file(inputs['features_mapping'].table)))

            self.yt.create_empty_table(
                path=edadeal_segments_dssm_features_table,
                schema={
                    lal_fields.group_id: 'string',
                    lal_fields.segment_float_features: 'string',
                    lal_fields.segment_affinitive_sites_ids: 'string',
                    lal_fields.segment_affinitive_apps: 'string',
                },
            )

            self.yt.run_map(
                lal_utils.MakeDssmSegmentFeaturesMapper(features_mapping=features_mapping),
                user_data_stats_edadeal_table,
                edadeal_segments_dssm_features_table,
                spec={'title': 'Make DSSM Segment Features'},
            )

            # 4. Apply the DSSM model to get one embedding per segment.
            query_string = templater.render_template(
                lal_utils.get_segments_embeddings_query_template,
                vars={
                    'model_path': lal_utils.get_lal_model_source_link(released=RELEASE_TYPE),
                    'segments_dssm_features_table': edadeal_segments_dssm_features_table,
                    'segments_dssm_vectors_table': segments_dssm_vectors_table,
                }
            )
            self.yql.query(
                query_string=query_string,
                transaction=transaction,
                title='Get edadeal segments embeddings for LaL',
            )

            # The segment vectors are loaded into memory and bound to the
            # mapper below, so every mapper job gets the full list.
            segments = self.get_segments_vectors(segments_dssm_vectors_table)

            # 5. Compute the distance from every user to every segment.
            self.yt.create_empty_table(
                path=dssm_lal_distances_table,
                schema={
                    lal_fields.yandexuid: 'uint64',
                    'segment_id': 'uint64',
                    lal_fields.distance: 'double',
                },
            )

            self.yt.run_map(
                functools.partial(calculate_dssm_lal_scores, segments=segments),
                self.yt.TablePath(
                    inputs['user_embeddings'].table,
                    columns=('user_id', 'embedding'),
                ),
                dssm_lal_distances_table,
                spec={'title': 'Calculate DSSM LaL scores for edadeal segments'},
            )

            # 6. Select the closest users of every segment into the output.
            # `.items()` (not the Python-2-only `.iteritems()`) keeps this
            # working on both Python 2 and Python 3.
            segment_id_to_name = ',\n'.join(
                "AsTuple({}, '{}')".format(segment_id, name)
                for segment_id, name in SEGMENTS_FOR_LAL.items()
            )
            self.yql.query(
                query_string=select_lal_users_query_template.format(
                    top_segments_for_user_cnt=4,
                    # Rendered into the query as the float literal 50000000.0.
                    max_volume=50e6,
                    user_data_edadeal_table=user_data_edadeal_table,
                    dssm_lal_distances_table=dssm_lal_distances_table,
                    segment_id_to_name=segment_id_to_name,
                    output_table=output_path,
                ),
                transaction=transaction,
                title='Select result users for edadeal LaL',
            )
