import os
import sys
if sys.version_info[0] < 3:
    from StringIO import StringIO
else:
    from io import StringIO

import numpy as np
import pandas as pd

from crypta.lib.python import templater
from crypta.lib.python.retryable_http_client import RetryableHttpClient
from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.tools import utils
from crypta.lib.python.tvm.helpers import get_tvm_headers
from crypta.lib.python.yt import yt_helpers

# YQL template (rendered with ``templater.render_template``) that normalizes the
# identifiers of a raw training sample and writes three tables (all WITH TRUNCATE):
#   * ``normalized_ids_table`` - one row per normalized id (id, id_type, retro_date, segment_name);
#     rows whose id normalizes to NULL are dropped;
#   * ``ids_stats_table``      - COUNT(DISTINCT id) per id_type, column ``cnt``;
#   * ``classes_stats_table``  - COUNT(DISTINCT id) per segment_name, column ``cnt``.
#
# Template variables:
#   libcrypta_identifier_udf - URL of libcrypta_identifier_udf.so (provides Identifiers::Normalize);
#   are_ids_flattened        - the sample already has ``id``/``id_type`` columns and
#                              ``segment_name`` is read as-is;
#   with_retro_date          - flattened layout only: read ``retro_date`` from the table
#                              instead of filling it with RANDOM(id);
#   id_type_to_column        - (id_type, column_name) pairs for the non-flattened layout;
#                              one UNION ALL branch is emitted per pair and ``target`` 1/0
#                              is mapped to 'positive'/'negative';
#   raw_sample_table, normalized_ids_table, ids_stats_table, classes_stats_table - table paths.
#
# NOTE(review): in the non-flattened branch a ``target`` other than 0/1 yields a NULL
# segment_name and such rows are not filtered out - confirm this is intended.
normalized_ids_query_template = """
PRAGMA File('libcrypta_identifier_udf.so', '{{ libcrypta_identifier_udf }}');
PRAGMA Udf('libcrypta_identifier_udf.so');

{% if are_ids_flattened %}

$normalized_ids = (
    SELECT
        Identifiers::Normalize(id_type, String::Strip(id)) AS id,
        id_type,
        {% if with_retro_date %}
        retro_date,
        {% else %}
        RANDOM(id) AS retro_date,
        {% endif %}
        segment_name,
    FROM `{{ raw_sample_table }}`
    WHERE Identifiers::Normalize(id_type, String::Strip(id)) IS NOT NULL
);

{% else %}

$normalized_ids = (
    {% for id_type, column_name in id_type_to_column %}
    {% if loop.index != 1 %}
        UNION ALL
    {% endif %}

    SELECT
        Identifiers::Normalize('{{ id_type }}', String::Strip({{ column_name }})) AS id,
        '{{ id_type }}' AS id_type,
        retro_date,
        CASE CAST(target AS Double)
            WHEN 1 THEN 'positive'
            WHEN 0 THEN 'negative'
            ELSE NULL
        END AS segment_name,
    FROM `{{ raw_sample_table }}`
    WHERE Identifiers::Normalize('{{ id_type }}', String::Strip({{ column_name }})) IS NOT NULL
    {% endfor %}
);

{% endif %}

INSERT INTO `{{ normalized_ids_table }}`
WITH TRUNCATE

SELECT *
FROM $normalized_ids;

INSERT INTO `{{ ids_stats_table }}`
WITH TRUNCATE

SELECT
    id_type,
    COUNT(DISTINCT id) AS cnt,
FROM $normalized_ids
GROUP BY id_type;

INSERT INTO `{{ classes_stats_table }}`
WITH TRUNCATE

SELECT
    segment_name,
    COUNT(DISTINCT id) AS cnt,
FROM $normalized_ids
GROUP BY segment_name;
"""

# YQL template that maps normalized ids to puids and writes (all WITH TRUNCATE):
#   * ``matching_stats_table`` - COUNT(DISTINCT id) of matched rows per source id_type
#     (columns ``id_type``, ``matched_cnt``);
#   * ``sample_by_puid_table`` - one row per puid; segment_name is taken from the row with
#     the largest retro_date (MAX_BY), and MAX(retro_date) is kept when ``with_retro_date``;
#   * ``classes_stats_table``  - per segment_name row count (``cnt``) and share of the new
#     sample (``ratio``).
#
# Template variables:
#   normalized_ids_table - table produced by ``normalized_ids_query_template``;
#   add_duid             - map duid rows to yandexuid through ``duid_matching`` first;
#                          otherwise duid rows are dropped by ``id_type != 'duid'``;
#   duid_matching        - table with ``duid`` and ``yandexuid`` columns;
#   matching_tables      - tables joined on (id, id_type), each providing ``target_id`` as
#                          the puid; rows whose id_type is already 'puid' pass through as-is;
#   with_retro_date      - whether ``retro_date`` is kept in ``sample_by_puid_table``
#                          (it is used for MAX_BY regardless);
#   matching_stats_table, sample_by_puid_table, classes_stats_table - output table paths.
#
# NOTE(review): a table listed twice in ``matching_tables`` duplicates rows under UNION ALL;
# the DISTINCT counts and the per-puid grouping mask this, but the join cost doubles.
matching_query_template = """
$pre_puid_matching = (
    SELECT
        id,
        id_type,
        segment_name,
        retro_date,
        id_type AS id_type_source,
    FROM `{{normalized_ids_table}}`
    WHERE id_type != 'duid'
    {% if add_duid %}
        UNION ALL
    SELECT
        CAST(yandexuid AS String) AS id,
        'yandexuid' AS id_type,
        'duid' AS id_type_source,
        retro_date,
        segment_name,
    FROM `{{normalized_ids_table}}` AS data
    INNER JOIN `{{duid_matching}}` AS matching
    ON data.id == CAST(matching.duid AS String)
    WHERE id_type == 'duid'
    {% endif %}
);

$matched_with_puid = (
    SELECT
        data.id AS id,
        data.id_type AS id_type,
        id_type_source,
        data.id AS puid,
        retro_date,
        segment_name,
    FROM $pre_puid_matching AS data
    WHERE id_type == 'puid'
{% for matching_table in matching_tables %}
        UNION ALL
    SELECT
        data.id AS id,
        data.id_type AS id_type,
        id_type_source,
        target_id AS puid,
        retro_date,
        segment_name,
    FROM $pre_puid_matching AS data
    INNER JOIN `{{matching_table}}` AS matching
    ON data.id == matching.id AND data.id_type == matching.id_type
{% endfor %}
);

INSERT INTO `{{matching_stats_table}}`
WITH TRUNCATE

SELECT
    id_type_source AS id_type,
    COUNT(DISTINCT id) AS matched_cnt,
FROM $matched_with_puid
GROUP BY id_type_source;

$new_sample = (
    SELECT
        puid AS id,
        'puid' AS id_type,
        MAX_BY(segment_name, retro_date) AS segment_name,
        {% if with_retro_date %}
        MAX(retro_date) AS retro_date,
        {% endif %}
    FROM $matched_with_puid
    GROUP BY puid
);

INSERT INTO `{{sample_by_puid_table}}`
WITH TRUNCATE

SELECT *
FROM $new_sample;

$new_sample_size = SELECT CAST(COUNT(*) AS Double) FROM $new_sample;

INSERT INTO `{{classes_stats_table}}`
WITH TRUNCATE

SELECT
    segment_name,
    COUNT(*) AS cnt,
    COUNT(*) / $new_sample_size AS ratio,
FROM $new_sample
GROUP BY segment_name;
"""


def normalize_ids(yt_client, yql_client, raw_sample_table, normalized_ids_table, with_retro_date, crypta_identifier_udf_url, transaction_id):
    """Normalize the raw sample's identifiers and collect per-id_type statistics.

    Renders ``normalized_ids_query_template`` and runs it. The query writes the
    normalized ids into ``normalized_ids_table`` and aggregates distinct-id counts
    per ``id_type`` and per ``segment_name`` into two temporary tables.

    The raw sample is treated as "flattened" when it has both ``id`` and ``id_type``
    columns. Otherwise, every column listed in ``training_config.id_type_to_column``
    that is present in the table is read as a separate id source.

    :param yt_client: YT client used to read the schema and results and to create temp tables.
    :param yql_client: client whose ``execute`` runs the YQL query.
    :param raw_sample_table: path of the input sample table.
    :param normalized_ids_table: path the normalized ids are written to (truncated).
    :param with_retro_date: whether the raw sample has a ``retro_date`` column.
    :param crypta_identifier_udf_url: URL of ``libcrypta_identifier_udf.so``.
    :param transaction_id: YT transaction the query runs in.
    :return: DataFrame with columns ``id_type`` and ``cnt`` (distinct normalized ids).
    :raises AssertionError: if no id survived normalization. When
        ``CRYPTA_ENVIRONMENT == 'stable'``, also if the sample or any class is too small.
    """
    with yt_client.TempTable() as ids_stats_table, yt_client.TempTable() as classes_stats_table:
        sample_table_schema = yt_helpers.get_yt_schema_dict_from_table(yt_client, raw_sample_table)
        are_ids_flattened = 'id' in sample_table_schema and 'id_type' in sample_table_schema

        # Keep only the id columns actually present in the sample; the template
        # emits one UNION ALL branch per (id_type, column_name) pair.
        id_type_to_column_filtered = {
            id_type: column_name
            for id_type, column_name in training_config.id_type_to_column.items()
            if column_name in sample_table_schema
        }

        normalized_ids_query = templater.render_template(
            template_text=normalized_ids_query_template,
            vars={
                'are_ids_flattened': are_ids_flattened,
                'with_retro_date': with_retro_date,
                'id_type_to_column': id_type_to_column_filtered.items(),
                'raw_sample_table': raw_sample_table,
                'normalized_ids_table': normalized_ids_table,
                'ids_stats_table': ids_stats_table,
                'classes_stats_table': classes_stats_table,
                'libcrypta_identifier_udf': crypta_identifier_udf_url,
            },
        )

        yql_client.execute(
            query=normalized_ids_query,
            title='YQL normalize ids and get id_types stats from raw_sample',
            transaction=str(transaction_id),
        )

        id_types_df = pd.DataFrame(list(yt_client.read_table(ids_stats_table)))
        # Distinct name: the temp-table path must not be shadowed by its contents.
        classes_stats_df = pd.DataFrame(list(yt_client.read_table(classes_stats_table)))
        assert len(id_types_df) > 0, 'Empty raw_sample table'

        # check that there is sufficient number of normalized ids
        unique_ids_count = id_types_df['cnt'].sum()
        if os.environ.get('CRYPTA_ENVIRONMENT') == 'stable':
            assert unique_ids_count > training_config.MIN_UNIQUE_IDS_CNT, \
                'The number of unique ids {} in raw_sample is too small.'.format(unique_ids_count)
            for _, row in classes_stats_df.iterrows():
                assert row['cnt'] > training_config.MIN_IDS_FOR_CLASS_CNT, \
                    'Ids number for {} class is too small = {}'.format(row['segment_name'], row['cnt'])

        return id_types_df


def match_ids_with_puid(yt_client, yql_client, id_types, normalized_ids_table, sample_by_puid_table, with_retro_date, logger, transaction_id):
    with yt_client.TempTable() as matching_stats_table, yt_client.TempTable() as classes_stats_table:
        add_duid = False
        matching_tables = []
        for id_type in id_types:
            matching_table = training_config.MATCHING_TABLE_TEMPLATE.format(id_type)
            # convert id to puid if matching table exists
            if yt_client.exists(matching_table):
                matching_tables.append(matching_table)
            # convert duid to yandexuid first, then convert yandexuid to puid
            elif id_type == 'duid':
                add_duid = True
                matching_table = training_config.MATCHING_TABLE_TEMPLATE.format('yandexuid')
                matching_tables.append(matching_table)
            elif id_type != 'puid':
                logger.warning('Unsupported id_type: {}'.format(id_type))

        matching_query = templater.render_template(
            template_text=matching_query_template,
            vars={
                'matching_tables': matching_tables,
                'add_duid': add_duid,
                'duid_matching': training_config.DUID_MATCHING,
                'normalized_ids_table': normalized_ids_table,
                'with_retro_date': with_retro_date,
                'sample_by_puid_table': sample_by_puid_table,
                'matching_stats_table': matching_stats_table,
                'classes_stats_table': classes_stats_table,
            },
        )

        yql_client.execute(
            query=matching_query,
            title='YQL get id-puid matching',
            transaction=str(transaction_id),
        )

        matching_stats_df = pd.DataFrame(list(yt_client.read_table(matching_stats_table)))
        assert len(matching_stats_df) > 0, 'No matched ids.'
        classes_stats_df = pd.DataFrame(list(yt_client.read_table(classes_stats_table)))

        return matching_stats_df, classes_stats_df


def prepare_sample_by_puid(
    yt_client,
    yql_client,
    raw_sample_table,
    sample_by_puid_table,
    crypta_identifier_udf_url,
    logger,
):
    with yt_client.Transaction() as transaction, yt_client.TempTable() as normalized_ids_table:
        table_columns = validate_input_table(
            yt_client=yt_client,
            raw_sample_table=raw_sample_table,
        )
        with_retro_date = 'retro_date' in table_columns

        initial_id_types_table = normalize_ids(
            yt_client=yt_client,
            yql_client=yql_client,
            raw_sample_table=raw_sample_table,
            normalized_ids_table=normalized_ids_table,
            with_retro_date=with_retro_date,
            crypta_identifier_udf_url=crypta_identifier_udf_url,
            transaction_id=transaction.transaction_id,
        )
        id_types = set(initial_id_types_table['id_type'].values)

        matching_stats_table, classes_stats_table = match_ids_with_puid(
            yt_client=yt_client,
            yql_client=yql_client,
            id_types=id_types,
            normalized_ids_table=normalized_ids_table,
            sample_by_puid_table=sample_by_puid_table,
            with_retro_date=with_retro_date,
            logger=logger,
            transaction_id=transaction.transaction_id,
        )

        stats_df = pd.merge(initial_id_types_table, matching_stats_table, on='id_type', how="outer", indicator=True)
        stats_df['ratio'] = stats_df['matched_cnt'] / stats_df['cnt']
        for _, row in stats_df.iterrows():
            if np.isnan(row['ratio']) or row['ratio'] < 0.1:
                logger.warning('Matching ratio for id_type {} is too low'.format(row['id_type']))

    return stats_df[['id_type', 'cnt', 'matched_cnt', 'ratio']], classes_stats_table


def save_sample_df_to_table(yt_client, sample_df, raw_sample_table):
    filtered_columns = []
    for column in sample_df.columns:
        if column in training_config.columns_for_classification:
            sample_df[column] = sample_df[column].astype(str)
            filtered_columns.append(column)

    columns = ', '.join(sample_df.columns)
    assert 'target' in filtered_columns and 'retro_date' in filtered_columns, \
        'Sample must contain "target" and "retro_date" columns. Current columns: {}'.format(columns)
    assert len(filtered_columns) > 2, 'Sample must contain at least one column with identifiers. Current columns: {}'.format(columns)

    yt_helpers.create_empty_table(
        yt_client,
        raw_sample_table,
        schema={column: 'string' for column in sample_df.columns},
    )

    yt_client.write_table(
        yt_client.TablePath(raw_sample_table, append=True),
        sample_df.to_dict('records'),
    )


def prepare_training_sample_table_from_file(yt_client, raw_sample_file, raw_sample_table):
    """Load a CSV sample stored as a YT file and save it as ``raw_sample_table``.

    :param yt_client: YT client used to read the file and write the table.
    :param raw_sample_file: YT path of the CSV file.
    :param raw_sample_table: destination table path.
    """
    csv_stream = yt_client.read_file(raw_sample_file)
    loaded_df = pd.read_csv(csv_stream)
    save_sample_df_to_table(yt_client, loaded_df, raw_sample_table)


def prepare_training_sample_table_from_audience(yt_client, audience_ids, raw_sample_table, logger):
    http_client = RetryableHttpClient(tries_count=3, delay=3, jitter=0, timeout_in_secs=600, logger=logger)
    tvm_ticket = utils.get_tvm_client(tvm_id=training_config.PROFILE_TVM_ID, tvm_secret_name='PROFILE_TVM_SECRET').get_service_ticket_for('audience')

    audience_ids = [segment.strip() for segment in str(audience_ids).split(',')]

    output_tables = []
    for idx, audience_id in enumerate(audience_ids):
        response = http_client._make_get_request(
            'http://audience-intapid.metrika.yandex.ru:9099/crypta/uploading_segment_data',
            params={'segment_id': audience_id},
            headers=get_tvm_headers(tvm_ticket),
        )

        sample_df = pd.read_csv(StringIO(response.text), index_col=None)
        logger.info('file: {}\ncolumns: {}\nsample:\n{}\n'.format(audience_id, sample_df.columns, response.text[:256]))

        if len(audience_ids) == 1:
            save_sample_df_to_table(yt_client, sample_df, raw_sample_table)
        else:
            output_table = '{}_part_{}'.format(raw_sample_table, idx)
            output_tables.append(output_table)
            save_sample_df_to_table(yt_client, sample_df, output_table)

    if len(audience_ids) > 1:
        yt_client.run_merge(
            output_tables,
            raw_sample_table,
        )


def check_table_columns_custom_format(columns):
    """Tell whether ``columns`` has the flattened layout: ``id``, ``id_type`` and ``segment_name``."""
    required_columns = ('id', 'id_type', 'segment_name')
    return all(name in columns for name in required_columns)


def check_table_columns_main_format(columns):
    """Tell whether ``columns`` has the per-id-column layout.

    Requires ``target``, ``retro_date`` and at least one identifier column named
    in ``training_config.id_type_to_column``.
    """
    has_target_and_date = 'target' in columns and 'retro_date' in columns
    if not has_target_and_date:
        return False

    known_id_columns = training_config.id_type_to_column.values()
    return any(column_name in columns for column_name in known_id_columns)


def validate_input_table(yt_client, raw_sample_table):
    """Check that ``raw_sample_table`` has one of the supported column layouts.

    :param yt_client: YT client used to read the table schema.
    :param raw_sample_table: path of the sample table.
    :return: the schema dict from ``yt_helpers.get_yt_schema_dict_from_table``.
    :raises AssertionError: if neither the flattened nor the per-id-column layout matches.
    """
    schema = yt_helpers.get_yt_schema_dict_from_table(
        yt_client,
        raw_sample_table,
    )

    if check_table_columns_custom_format(schema):
        is_valid = True
    else:
        is_valid = check_table_columns_main_format(schema)
    assert is_valid, 'Columns set in train sample is not right: {}'.format(', '.join(schema))

    return schema
