#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datetime import datetime
import logging

import numpy as np

from yt.wrapper import with_context

from crypta.lib.python.yql import yql_helpers
from crypta.lib.python.yt import yt_helpers

from crypta.lookalike.lib.python.utils import fields
from crypta.lookalike.lib.python.utils.mobile_config import config as mobile_config
from crypta.lookalike.lib.python.utils.utils import (
    get_date_from_past,
    normalize,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Input table indices for ``apps_scores_reducer``: the reducer compares
# ``context.table_index`` against the PROMOTED and CATEGORIES values to tell
# which distances table a row came from; any other index (TOP_APPS) is
# handled by its fall-through branch.
DISTANCES_PROMOTED_APPS_TABLE_IDX = 0
DISTANCES_TOP_APPS_TABLE_IDX = 1
DISTANCES_CATEGORIES_TABLE_IDX = 2

# Table indices for the two kinds of score tables. They are not referenced in
# the visible part of this module -- presumably used by code that consumes
# the reducers' outputs (TODO confirm).
APPS_AND_CATEGORIES_SCORES_TABLE_IDX = 0
INSTALLERS_BY_AD_SCORES_TABLE_IDX = 1


# YQL template that links device ids and their installed apps to cryptaId.
#
# Placeholders (filled with ``str.format``):
#   {gaid_to_crypta_id_matching}, {idfa_to_crypta_id_matching} -- matching
#       tables with ``id``, ``id_type`` and ``target_id`` (read as cryptaId);
#       they are concatenated with UNION ALL.
#   {apps_by_devid} -- table with ``id``, ``id_type`` and a Yson ``apps`` list.
#   {additional_options} -- raw YQL text inserted right after the LEFT JOIN's
#       USING clause (may be an empty string).
#   {app_data} -- table with ``BundleId`` and ``SourceID``; SourceID == 1 is
#       mapped to id_type 'gaid', everything else to 'idfa'.
#   {devid_by_app_with_crypta_id}, {apps_segments}, {crypta_id_to_id_type} --
#       output tables, each overwritten (WITH TRUNCATE).
#
# Steps:
#   1. LEFT JOIN devices to the matching tables, so devices without a
#      cryptaId are kept with cryptaId = NULL.
#   2. Flatten the per-device ``apps`` list and INNER JOIN each entry against
#      the lower-cased BundleId for the same id_type; the emitted ``app_id``
#      is the BundleId in its original case.
#   3. Write the flattened rows as is, then the (GroupID, IdValue, IdType)
#      segments for rows with a non-NULL cryptaId, where GroupID is
#      ``app_id || '__' || id_type``, and finally the distinct
#      (id_type, cryptaId) pairs (no NULL filter on the last one).
get_users_by_crypta_id_query = """
$matching = (
    SELECT
        id,
        id_type,
        target_id AS cryptaId,
    FROM `{gaid_to_crypta_id_matching}`
UNION ALL
    SELECT
        id,
        id_type,
        target_id AS cryptaId,
    FROM `{idfa_to_crypta_id_matching}`
);

$matched_with_crypta_id = (
    SELECT
        apps_by_devid_daily.id AS id,
        apps_by_devid_daily.id_type AS id_type,
        apps,
        cryptaId,
    FROM `{apps_by_devid}` AS apps_by_devid_daily
    LEFT JOIN $matching AS devid_crypta_id
    USING (id, id_type)
    {additional_options}
);

$lowercase_matching = (
    SELECT DISTINCT
        BundleId AS app_id,
        String::AsciiToLower(BundleId) AS app_id_lc,
        IF(SourceID == 1, 'gaid', 'idfa') AS id_type,
    FROM `{app_data}`
);

$matched_with_app_id = (
    SELECT
        id,
        data.id_type AS id_type,
        cryptaId,
        lowercase_matching.app_id AS app_id,
    FROM $matched_with_crypta_id AS data
    FLATTEN LIST BY (Yson::ConvertToStringList(data.apps) AS app_id)
    INNER JOIN $lowercase_matching AS lowercase_matching
    ON data.app_id == lowercase_matching.app_id_lc AND data.id_type == lowercase_matching.id_type
);

INSERT INTO `{devid_by_app_with_crypta_id}`
WITH TRUNCATE

SELECT *
FROM $matched_with_app_id;

INSERT INTO `{apps_segments}`
WITH TRUNCATE

SELECT
    app_id || '__' || id_type AS GroupID,
    cryptaId AS IdValue,
    'crypta_id' AS IdType,
FROM $matched_with_app_id
WHERE cryptaId IS NOT NULL;

INSERT INTO `{crypta_id_to_id_type}`
WITH TRUNCATE

SELECT
    id_type,
    cryptaId,
FROM $matched_with_app_id
GROUP BY id_type, cryptaId;
"""


# YQL template that keeps only the segment vectors of apps listed in an info
# table.
#
# Placeholders (filled with ``str.format``):
#   {vectors_table} -- table with ``GroupID`` and ``segment_vector``.
#   {info_table} -- table with ``app_id``, ``id_type`` and ``game``.
#   {filtered_vectors_table} -- output table, overwritten (WITH TRUNCATE).
#
# Rows are matched on ``app_id || '__' || id_type == GroupID``. ``MD5Hash`` is
# Digest::Md5HalfMix of the app id (empty string when NULL) followed by
# '-Google Play' for id_type 'gaid' and '-App Store' otherwise.
filter_top_and_promoted = """
INSERT INTO `{filtered_vectors_table}`
WITH TRUNCATE

SELECT
    apps.app_id AS app_id,
    apps.id_type AS id_type,
    apps.game AS game,
    Digest::Md5HalfMix((apps.app_id ?? '') || IF(apps.id_type == 'gaid', '-Google Play', '-App Store')) AS MD5Hash,
    vectors.segment_vector AS vector,
FROM `{vectors_table}` AS vectors
INNER JOIN `{info_table}` AS apps
ON apps.app_id || '__' || apps.id_type == vectors.GroupID;
"""


def get_yt_client(nv_params):
    """Create a YT client from ``nv_params``.

    Delegates to ``yt_helpers.get_yt_client_from_nv_parameters`` and passes
    ``mobile_config.COMMON_TMP_DIRECTORY`` as its second argument.
    """
    tmp_directory = mobile_config.COMMON_TMP_DIRECTORY
    return yt_helpers.get_yt_client_from_nv_parameters(nv_params, tmp_directory)


def get_yql_client(nv_params):
    """Create a YQL client from ``nv_params``.

    Delegates to ``yql_helpers.get_yql_client_from_nv_parameters`` and passes
    ``mobile_config.COMMON_YQL_TMP_DIRECTORY`` as its second argument.
    """
    yql_tmp_directory = mobile_config.COMMON_YQL_TMP_DIRECTORY
    return yql_helpers.get_yql_client_from_nv_parameters(nv_params, yql_tmp_directory)


def vector_row_to_features(row, vector_field=fields.vector):
    """Decode a packed vector stored in ``row`` and scale it to unit length.

    Args:
        row: mapping that holds the raw vector bytes under ``vector_field``.
        vector_field: key of the packed vector inside ``row``.

    Returns:
        A numpy array with the vector divided by its Euclidean norm.

    Note:
        The byte length must be a multiple of 4 (one little-endian float32
        per element). A zero vector is not special-cased: the division then
        produces NaNs.
    """
    # np.frombuffer replaces the deprecated binary mode of np.fromstring; it
    # interprets the same bytes without copying them. The resulting view is
    # read-only, which is fine because the division allocates a new array.
    vector = np.frombuffer(row[vector_field], dtype='<f4')
    return vector / np.sqrt(np.dot(vector, vector))


def get_app2vec(row, category_to_vector, vector_field=fields.vector):
    """Return the embedding for an app row, falling back to its category.

    Args:
        row: mapping with the packed vector under ``vector_field`` and the
            category name under ``fields.category``.
        category_to_vector: mapping from lower-cased category name to a
            vector.
        vector_field: key of the packed vector inside ``row``.

    Returns:
        The unit-length vector decoded by ``vector_row_to_features`` when the
        row has one; otherwise the category vector as a numpy array when the
        lower-cased category is present in ``category_to_vector``; otherwise
        ``None``.
    """
    # The row's own vector always wins over the category fallback.
    if row[vector_field] is not None:
        return vector_row_to_features(row, vector_field)

    category = row[fields.category]
    if category is None:
        return None

    category_key = category.lower()
    if category_key in category_to_vector:
        return np.array(category_to_vector[category_key])

    return None


def segments_lal_scores_mapper(row, segments):
    """Yield the distance between one user row and every segment.

    Args:
        row: user record with ``fields.id_type``, ``fields.cryptaId`` and
            ``fields.user_vector``.
        segments: iterable of segment records with ``fields.group_id`` and
            ``fields.vector``.

    Yields:
        One dict per segment with the user's id type, the cryptaId converted
        to ``str``, the segment's group id and the distance, computed as
        ``1 - dot(normalize(segment_vector), normalize(user_vector))``.
    """
    # Everything derived from ``row`` alone is the same for every segment, so
    # it is computed once instead of on each iteration.
    id_type = row[fields.id_type]
    crypta_id = str(row[fields.cryptaId])
    user_vector = normalize(row[fields.user_vector])

    for segment in segments:
        yield {
            fields.id_type: id_type,
            fields.cryptaId: crypta_id,
            fields.group_id: segment[fields.group_id],
            fields.distance: 1. - np.dot(normalize(segment[fields.vector]), user_vector),
        }


def categories_lal_scores_mapper(row, segments):
    """Yield the distance between one user row and each matching category.

    Only segments whose ``fields.id_type`` equals the row's id type are
    scored; the rest are skipped.

    Args:
        row: user record with ``fields.id_type``, ``fields.cryptaId`` and
            ``fields.user_vector``.
        segments: iterable of category records with ``fields.id_type``,
            ``fields.cluster_id`` and ``fields.vector``.

    Yields:
        One dict per matching segment with the user's id type, the cryptaId
        converted to ``str``, the segment's cluster id and the distance,
        computed as
        ``1 - dot(normalize(segment_vector), normalize(user_vector))``.
    """
    # Everything derived from ``row`` alone is the same for every segment, so
    # it is computed once instead of on each iteration.
    id_type = row[fields.id_type]
    crypta_id = str(row[fields.cryptaId])
    user_vector = normalize(row[fields.user_vector])

    for segment in segments:
        if segment[fields.id_type] != id_type:
            continue

        yield {
            fields.id_type: id_type,
            fields.cryptaId: crypta_id,
            fields.cluster_id: segment[fields.cluster_id],
            fields.distance: 1. - np.dot(normalize(segment[fields.vector]), user_vector),
        }


def apps_lal_scores_mapper(row, segments):
    """Yield the distance between one user row and each matching app.

    Only segments whose ``fields.id_type`` equals the row's id type are
    scored; the rest are skipped.

    Args:
        row: user record with ``fields.id_type``, ``fields.cryptaId``,
            ``fields.user_vector`` and ``fields.apps`` (a container of app
            ids, used for the membership test that produces the label).
        segments: iterable of app records with ``fields.id_type``,
            ``fields.app_id``, ``fields.vector``, ``fields.game`` and
            ``fields.MD5Hash``.

    Yields:
        One dict per matching segment with the user's id type, the cryptaId
        converted to ``str``, the app id, the distance
        (``1 - dot(normalize(segment_vector), normalize(user_vector))``), a
        label that is 1 when the app id is contained in the row's apps and 0
        otherwise, plus the segment's ``game`` flag and ``MD5Hash``.
    """
    target_segments = row[fields.apps]

    # Everything derived from ``row`` alone is the same for every segment, so
    # it is computed once instead of on each iteration.
    id_type = row[fields.id_type]
    crypta_id = str(row[fields.cryptaId])
    user_vector = normalize(row[fields.user_vector])

    for segment in segments:
        if segment[fields.id_type] != id_type:
            continue

        app_id = segment[fields.app_id]
        yield {
            fields.id_type: id_type,
            fields.cryptaId: crypta_id,
            fields.app_id: app_id,
            fields.distance: 1. - np.dot(normalize(segment[fields.vector]), user_vector),
            fields.label: 1 if app_id in target_segments else 0,
            fields.game: segment[fields.game],
            fields.MD5Hash: segment[fields.MD5Hash],
        }


def get_apps(yt, yql, vectors_table_path, info_table_path, transaction):
    """Load the vectors of the apps listed in ``info_table_path``.

    Runs the ``filter_top_and_promoted`` query into a temporary table and
    reads the whole result into memory.

    Args:
        yt: YT client; must provide ``TempTable`` and ``read_table``.
        yql: YQL client; must provide ``execute``.
        vectors_table_path: path substituted for ``{vectors_table}``.
        info_table_path: path substituted for ``{info_table}``.
        transaction: object whose ``transaction_id`` is passed, as a string,
            to the YQL query.

    Returns:
        A list with every row of the filtered table.
    """
    with yt.TempTable() as filtered_table:
        query = filter_top_and_promoted.format(
            vectors_table=vectors_table_path,
            info_table=info_table_path,
            filtered_vectors_table=filtered_table,
        )
        yql.execute(
            query=query,
            transaction=str(transaction.transaction_id),
            title='YQL Filter top/promoted apps vectors',
        )
        # Materialise the rows while the temporary table is still in scope.
        return list(yt.read_table(filtered_table))


def get_top_by_scores(scores, limit):
    """Return the ``limit`` best-scored ids as an ``{str(id): score}`` dict.

    Args:
        scores: iterable of ``(score, id)`` pairs.
        limit: how many of the highest-ranked pairs to keep.

    Returns:
        A dict keyed by the stringified id. Pairs are ranked in descending
        tuple order (score first, then id). If the same id occurs more than
        once among the kept pairs, the lower-ranked occurrence overwrites the
        higher-ranked one.
    """
    ranked = sorted(scores, reverse=True)
    return {str(item_id): score for score, item_id in ranked[:limit]}


def installers_segments_scores_reducer(key, records):
    """Collapse all segment distances of one user into a single row.

    Args:
        key: reduce key with ``fields.id_type`` and ``fields.cryptaId``.
        records: rows with ``fields.group_id`` and ``fields.distance``.

    Yields:
        One dict with the key fields and, under
        ``fields.installs_by_ads_scores``, a mapping from group id to
        ``1 - distance``. If a group id repeats, the last row wins.
    """
    scores_by_group = {
        record[fields.group_id]: 1 - record[fields.distance]
        for record in records
    }

    yield {
        fields.id_type: key[fields.id_type],
        fields.cryptaId: key[fields.cryptaId],
        fields.installs_by_ads_scores: scores_by_group,
    }


@with_context
def apps_scores_reducer(key, records, context):
    """Pick the top apps and categories for one user from distance rows.

    Rows are routed by ``context.table_index``:

    * rows from the categories table contribute
      ``(1 - distance, cluster_id)`` to the category ranking;
    * rows from any other table are used only when ``fields.label`` is 0.
      Each contributes ``(1 - distance, MD5Hash)`` to the game or non-game
      ranking depending on ``fields.game``; rows from the promoted apps
      table are added to the promoted ranking as well.

    Args:
        key: reduce key with ``fields.id_type`` and ``fields.cryptaId``.
        records: distance rows for this key.
        context: YT job context supplying ``table_index`` of the current row.

    Yields:
        One dict with the key fields, the merged top game and non-game apps
        under ``fields.top_common_lal_apps``, the top promoted apps under
        ``fields.promoted`` and the top categories under
        ``fields.top_common_lal_categories``. Sizes of the tops come from
        ``mobile_config``.
    """
    categories = []
    promoted = []
    games = []
    non_games = []

    for record in records:
        # Must be read per row: it identifies the source table of ``record``.
        table_index = context.table_index

        if table_index == DISTANCES_CATEGORIES_TABLE_IDX:
            categories.append((1 - record[fields.distance], record[fields.cluster_id]))
            continue

        if record[fields.label] != 0:
            continue

        scored_app = (1 - record[fields.distance], record[fields.MD5Hash])
        if table_index == DISTANCES_PROMOTED_APPS_TABLE_IDX:
            promoted.append(scored_app)

        bucket = games if record[fields.game] else non_games
        bucket.append(scored_app)

    common_apps = get_top_by_scores(games, mobile_config.GAME_APPS_CNT)
    common_apps.update(get_top_by_scores(non_games, mobile_config.NON_GAME_APPS_CNT))
    promoted_apps = get_top_by_scores(promoted, mobile_config.RECOMMENDED_PROMOTED_APPS_CNT)
    common_categories = get_top_by_scores(categories, mobile_config.TOP_COMMON_CATEGORIES_CNT)

    yield {
        fields.id_type: key[fields.id_type],
        fields.cryptaId: key[fields.cryptaId],
        fields.top_common_lal_apps: common_apps,
        fields.promoted: promoted_apps,
        fields.top_common_lal_categories: common_categories,
    }


def get_date_from_nv_parameters(nv_params):
    """Return the date of ``nv_params['timestamp']`` as ``'YYYY-MM-DD'``.

    The timestamp (anything ``int`` accepts, in seconds since the epoch) is
    converted with ``datetime.fromtimestamp``, i.e. in the local time zone
    of the process.
    """
    seconds = int(nv_params['timestamp'])
    moment = datetime.fromtimestamp(seconds)
    return moment.strftime('%Y-%m-%d')


def check_date(yt_client, table_path, nv_params, gap_days=0):
    """Tell whether the table's ``generate_date`` is recent enough.

    Args:
        yt_client: YT client; must provide ``exists`` and ``get_attribute``.
        table_path: path of the table to inspect.
        nv_params: parameters holding the target ``timestamp``.
        gap_days: passed as ``days`` to ``get_date_from_past`` together with
            the target date to obtain the oldest acceptable date.

    Returns:
        ``True`` when the table exists, has a ``generate_date`` attribute and
        the first 10 characters of that attribute compare greater than or
        equal to the oldest acceptable date; ``False`` otherwise.
    """
    target_date = get_date_from_nv_parameters(nv_params)

    if not yt_client.exists(table_path):
        return False

    generate_date = yt_client.get_attribute(table_path, 'generate_date', None)
    if generate_date is None:
        return False

    # Only the leading date part of the attribute takes part in the
    # comparison (presumably 'YYYY-MM-DD' -- confirm against writers).
    oldest_acceptable = get_date_from_past(target_date, days=gap_days)
    return bool(oldest_acceptable <= generate_date[:10])


def set_generate_date(yt_client, table_path, nv_params):
    """Store the target date in the table's ``generate_date`` attribute.

    Args:
        yt_client: YT client; must provide ``set_attribute``.
        table_path: path of the table to mark.
        nv_params: parameters holding the target ``timestamp``; the date is
            derived with ``get_date_from_nv_parameters``.
    """
    yt_client.set_attribute(
        path=table_path,
        attribute='generate_date',
        value=get_date_from_nv_parameters(nv_params),
    )
