import os
import logging

from google.protobuf import json_format
import numpy as np
import pandas as pd

from crypta.lib.python import classification_thresholds
from crypta.lib.python.nirvana.nirvana_helpers.nirvana_transaction import NirvanaTransaction
from crypta.lib.python.yt import yt_helpers
from crypta.prism.lib.config import config
from crypta.prism.proto.thresholds_pb2 import TThresholds
from crypta.prism.services.training.lib import utils


logger = logging.getLogger(__name__)

# YQL template executed by find(); placeholders are filled with str.format.
#
# It contains two `INSERT INTO ... WITH TRUNCATE` statements:
#   1. Into {distribution_for_thresholds_table}: for every `target` value in
#      {raw_train_sample_table}, the share of rows with that target
#      (`cids_ratio` = group row count / total row count), ordered by target.
#   2. Into {model_predictions_for_thresholds_table}: a histogram of model
#      predictions. The five `Probability:Class=N` columns of
#      {model_predictions_table} are rounded with Math::Round(..., -2) into
#      p1..p5, only rows whose rounded values sum to 1 are kept, and the rows
#      are counted per distinct (p1, ..., p5) combination as `cnt`.
#
# NOTE(review): `p1 + p2 + p3 + p4 + p5 == 1` is an exact comparison on Double
# values; rows whose rounded probabilities do not add up to exactly 1 are
# dropped. Presumably intentional for the thresholds search -- confirm.
histogram_query = """
$cids_cnt = (
    SELECT CAST(COUNT(*) AS Double)
    FROM `{raw_train_sample_table}`
);

INSERT INTO `{distribution_for_thresholds_table}`
WITH TRUNCATE

SELECT
    target,
    COUNT(*) / $cids_cnt AS cids_ratio
FROM `{raw_train_sample_table}`
GROUP BY target
ORDER BY target;

$rounded = (
    SELECT
        Math::Round(`Probability:Class=0`, -2) AS p1,
        Math::Round(`Probability:Class=1`, -2) AS p2,
        Math::Round(`Probability:Class=2`, -2) AS p3,
        Math::Round(`Probability:Class=3`, -2) AS p4,
        Math::Round(`Probability:Class=4`, -2) AS p5,
    FROM `{model_predictions_table}`
);

INSERT INTO `{model_predictions_for_thresholds_table}`
WITH TRUNCATE

SELECT
    p1,
    p2,
    p3,
    p4,
    p5,
    COUNT(*) AS cnt
FROM $rounded
WHERE p1 + p2 + p3 + p4 + p5 == 1
GROUP BY p1, p2, p3, p4, p5
ORDER BY p1, p2, p3, p4, p5;
"""


def write_segment_stats_to_yt(yt_client, stats, date):
    """Write per-segment statistics to the DataLens realtime prism tables.

    Two kinds of statistics are written, one ``yt_helpers.write_stats_to_yt``
    call each:

    * ``stats['train_distribution']`` goes to the ``train_distribution`` table
      with the value stored in a ``ratio`` column;
    * ``stats['classification_thresholds']`` goes to the
      ``classification_thresholds`` table with the value stored in a
      ``threshold`` column.

    Values are paired positionally with ``utils.SEGMENTS`` via ``zip``, so the
    shorter of the two sequences determines how many rows are produced.

    Args:
        yt_client: YT client forwarded to ``yt_helpers.write_stats_to_yt``.
        stats: mapping with the ``train_distribution`` and
            ``classification_thresholds`` sequences.
        date: forwarded unchanged to ``yt_helpers.write_stats_to_yt``.
    """
    # (name of the value column, key in `stats` that doubles as the table name)
    columns_and_tables = (
        ('ratio', 'train_distribution'),
        ('threshold', 'classification_thresholds'),
    )

    for value_column, table_name in columns_and_tables:
        rows = [
            {
                'segment': segment_name,
                value_column: value,
            }
            for segment_name, value in zip(utils.SEGMENTS, stats[table_name])
        ]
        destination = os.path.join(config.DATALENS_REALTIME_PRISM_DIR, table_name)
        row_schema = {
            'segment': 'string',
            value_column: 'double',
        }

        yt_helpers.write_stats_to_yt(
            yt_client=yt_client,
            table_path=destination,
            data_to_write=rows,
            schema=row_schema,
            date=date,
        )


def find(yt_client, yql_client, date, output):
    """Compute classification thresholds and dump them as JSON.

    All steps run inside a single ``NirvanaTransaction``:

    1. Execute ``histogram_query`` to refresh the train distribution table and
       the rounded predictions histogram table.
    2. Read the ``cids_ratio`` column of the distribution table, in the order
       the rows are returned.
    3. Pass the predictions histogram (as a ``pandas.DataFrame``) and the
       distribution (as needed recalls) to
       ``classification_thresholds.find_thresholds``.
    4. Write both the distribution and the thresholds to the stats tables via
       ``write_segment_stats_to_yt``.
    5. Serialize the thresholds into a ``TThresholds`` message and write it as
       JSON to ``output``.

    Args:
        yt_client: YT client used for the transaction and table reads.
        yql_client: client whose ``execute`` runs the YQL query.
        date: forwarded to ``write_segment_stats_to_yt``.
        output: path of the file that receives the JSON-encoded thresholds.
    """
    with NirvanaTransaction(yt_client) as tx:
        query_text = histogram_query.format(
            raw_train_sample_table=config.RAW_TRAIN_SAMPLE_TABLE,
            model_predictions_table=config.MODEL_PREDICTIONS_TABLE,
            distribution_for_thresholds_table=config.DISTRIBUTION_FOR_THRESHOLDS_TABLE,
            model_predictions_for_thresholds_table=config.MODEL_PREDICTIONS_FOR_THRESHOLDS_TABLE,
        )
        yql_client.execute(
            query=query_text,
            transaction=str(tx.transaction_id),
            title='YQL realtime prism make predictions hist',
        )

        # One ratio per row of the distribution table, kept in read order.
        train_ratios = [
            record['cids_ratio']
            for record in yt_client.read_table(config.DISTRIBUTION_FOR_THRESHOLDS_TABLE)
        ]

        logger.info('Start computing thresholds')

        predictions_hist = pd.DataFrame(
            yt_client.read_table(config.MODEL_PREDICTIONS_FOR_THRESHOLDS_TABLE),
        )
        found_thresholds = classification_thresholds.find_thresholds(
            table=predictions_hist,
            segments=utils.SEGMENTS,
            needed_recalls=np.array(train_ratios),
            constant_for_full_coverage=0.5,
        )

        segment_stats = {
            'train_distribution': train_ratios,
            'classification_thresholds': found_thresholds,
        }
        write_segment_stats_to_yt(yt_client=yt_client, stats=segment_stats, date=date)

        logger.info('Computed thresholds')

        message = TThresholds()
        message.Thresholds.extend(found_thresholds)
        # Serialize inside the `with` so the file is opened (and truncated)
        # before the message is converted to JSON.
        with open(output, 'w') as output_file:
            serialized = json_format.MessageToJson(message, sort_keys=True)
            output_file.write(serialized)

        logger.info('Successfully dumped thresholds')
