import os
from abc import (
    ABCMeta,
    abstractmethod,
)

from cached_property import cached_property
from google.protobuf.json_format import MessageToDict
import six

from crypta.lib.python.custom_ml import training_config
from crypta.lib.python.custom_ml.classification.model_application import (
    apply_model_to_profiles,
    get_initial_segments,
    get_modeled_segments,
    get_percentiles_for_currently_computed_segments,
)
from crypta.lib.python.custom_ml.classification.prepare_train_sample import (
    prepare_sample_by_puid,
    prepare_training_sample_table_from_audience,
    prepare_training_sample_table_from_file,
)
from crypta.lib.python.custom_ml.classification.train_helper import (
    CustomClassificationModelTrainHelper,
    CustomClassificationParams,
)
import crypta.lib.proto.identifiers.id_pb2 as IdProto
from crypta.lib.python.custom_ml.proto import classification_pb2
from crypta.lib.python.custom_ml.samples.new_sample import (
    add_new_sample_for_existing_industry,
    merge_training_samples,
)
from crypta.lib.python.custom_ml.tools import (
    training_utils,
    metrics,
    utils,
)
from crypta.siberia.bin.common.describing.experiment.proto import describing_experiment_pb2
from crypta.siberia.bin.common.siberia_client import SiberiaClient


# Classification flavour of the shared train-sample query template.
# The target placeholders are bound to `segment_name` here, while the table
# paths and the additional conditions are substituted with placeholders of the
# same name, so they can be filled in by a second `.format(...)` call later.
make_common_classification_train_sample_query = training_utils.make_common_train_sample_query.format(**{
    'target_processing': 'train_sample.segment_name',
    'target': 'segment_name',
    'input_table': '{input_table}',
    'output_table': '{output_table}',
    'additional_conditions': '{additional_conditions}',
})


class AnyCustomClassificationModelTrainHelper(CustomClassificationModelTrainHelper):
    """
    Train helper bound to a caller-supplied train sample table.

    Args:
        yt: YT client forwarded to the base helper.
        train_sample_path: path of the train sample table put into the model params.
        logger: logger forwarded to the base helper; a stderr logger is created
            when it is omitted.
        model_name: name used to derive the metrics group name.
    """

    def __init__(self, yt, train_sample_path, logger=None, model_name='new_model'):
        # The fallback logger is created here rather than in the signature:
        # a call in a default value is evaluated once, at import time.
        if logger is None:
            logger = utils.get_stderr_logger()
        super(AnyCustomClassificationModelTrainHelper, self).__init__(yt=yt, logger=logger, date=None)
        self.train_sample_path = train_sample_path
        self.model_name = model_name

    @cached_property
    def model_params(self):
        """
        Build (once per instance) the parameters describing this model.

        Only the train sample path, the metrics group name and the segment
        id <-> name mappings are filled in; the remaining fields are None.
        NOTE(review): `segment_id_to_name` and `get_metrics_group_name` are
        presumably provided by the base helper - confirm.
        """
        return CustomClassificationParams(
            train_sample_path=self.train_sample_path,
            metrics_group_name=self.get_metrics_group_name(self.model_name),
            segment_id_to_name=self.segment_id_to_name,
            segment_name_to_id=training_utils.revert_dict(self.segment_id_to_name),
            ordered_thresholds=None,
            model_description_in_sandbox=None,
            make_train_sample_query=None,
            model_tag=None,
            resource_type=None,
        )


def prepare_train_sample(yql_client, sample_by_puid_path, sample_by_yuid_path, additional_conditions=''):
    """
    Build the training sample (sample_by_yuid) out of sample_by_puid.

    Fills the classification train-sample query template with the given table
    paths and extra conditions and runs it through `yql_client`.
    """
    train_sample_query = make_common_classification_train_sample_query.format(
        input_table=sample_by_puid_path,
        output_table=sample_by_yuid_path,
        additional_conditions=additional_conditions,
    )
    yql_client.execute(query=train_sample_query, title='YQL get train sample from sample_by_puid')


def log_results(logger, output, string_info):
    """
    Log `string_info` at INFO level and append it, as bytes, to `output`.
    """
    logger.info(string_info)
    encoded_info = six.ensure_binary(string_info)
    output.write(encoded_info)


# A `__metaclass__` class attribute is ignored by Python 3, so the metaclass is
# applied through six to keep `run` abstract on both major versions.
@six.add_metaclass(ABCMeta)
class TrainAndValidateModelHelper(object):
    """
    Base class of the train-and-validate pipelines.

    Keeps the YT/YQL clients, the output directory and the optional API client,
    and implements the shared steps: preparing sample_by_puid, training a model
    on a sample, saving the model and reporting status, metrics and errors.
    Subclasses implement `run`.
    """

    def __init__(
        self,
        yt_client,
        yql_client,
        output_dir,
        raw_sample_table=None,
        raw_sample_file=None,
        audience_id=None,
        validate_segments=False,
        crypta_identifier_udf_url=training_config.IDENTIFIER_UDF_PATH,
        number_of_top_features=training_config.DEFAULT_NUMBER_OF_TOP_FEATURES,
        logger=None,
        send_results_to_api=False,
    ):
        """
        Args:
            yt_client: YT client used for table and model operations.
            yql_client: YQL client used to run queries.
            output_dir: directory where output tables are placed; its basename
                is used as the sample id.
            raw_sample_table: raw sample table path; defaults to the RAW_SAMPLE
                table inside `output_dir`.
            raw_sample_file: if given, the raw sample table is prepared from it.
            audience_id: if given (and no file is given), the raw sample table
                is prepared from it.
            validate_segments: stored flag consulted by subclasses in `run`.
            crypta_identifier_udf_url: forwarded to `prepare_sample_by_puid`.
            number_of_top_features: number of top features requested from the
                train helper.
            logger: logger to use; a stderr logger is created when omitted.
            send_results_to_api: when true, an API client is created and
                status, metrics and errors are reported through it.
        """
        self.yt_client = yt_client
        self.yql_client = yql_client
        self.send_results_to_api = send_results_to_api
        self.api = None
        if self.send_results_to_api:
            self.api = utils.get_api(os.getenv('CRYPTA_ENVIRONMENT') or 'testing')
        self.crypta_identifier_udf_url = crypta_identifier_udf_url
        self.logger = logger or utils.get_stderr_logger()
        self.file_output = six.BytesIO()

        # output_dir has to be assigned before the default raw sample table is
        # computed: get_output_table_path() reads self.output_dir.
        self.output_dir = output_dir
        self.sample_id = os.path.basename(output_dir)
        self.raw_sample_table = raw_sample_table or self.get_output_table_path(training_config.RAW_SAMPLE)
        self.raw_sample_file = raw_sample_file
        self.audience_id = audience_id
        self.validate_segments = validate_segments
        self.number_of_top_features = number_of_top_features

    def get_output_table_path(self, table_name):
        """
        Return the path of `table_name` inside the output directory.
        """
        return os.path.join(self.output_dir, table_name)

    def send_error(self, message):
        """
        Report a training error for this sample to the API (no-op when API reporting is off).
        """
        if not self.send_results_to_api:
            return

        error_proto = classification_pb2.TTrainingError(
            sample_id=self.sample_id,
            message=message,
        )
        self.logger.info(self.api.lab.writeTrainingSampleError(**MessageToDict(error_proto)).result())

    def send_metrics(self, metrics_proto):
        """
        Report training metrics to the API (no-op when API reporting is off).

        The top features are sent as a single comma-separated string.
        """
        if not self.send_results_to_api:
            return

        parameters = MessageToDict(metrics_proto)
        # MessageToDict leaves out empty fields by default, so the key may be
        # absent when there are no top features.
        # NOTE(review): assumes topFeatures is a repeated string field - confirm.
        parameters['topFeatures'] = ', '.join(parameters.get('topFeatures', []))

        self.logger.info(self.api.lab.writeTrainingSampleMetrics(**parameters).result())

    def get_sample_by_puid(self):
        """
        Prepare the raw sample table (from a file or an audience, if either is
        given) and compute sample_by_puid from it.

        Returns:
            (matching_stats, classes_stats) as returned by `prepare_sample_by_puid`;
            both are also logged and appended to the file output.
        """
        if self.raw_sample_file is not None:
            prepare_training_sample_table_from_file(
                yt_client=self.yt_client,
                raw_sample_file=self.raw_sample_file,
                raw_sample_table=self.raw_sample_table,
            )
        elif self.audience_id is not None:
            prepare_training_sample_table_from_audience(
                yt_client=self.yt_client,
                audience_ids=self.audience_id,
                raw_sample_table=self.raw_sample_table,
                logger=self.logger,
            )

        self.logger.info('Calculating sample_by_puid')
        matching_stats, classes_stats = prepare_sample_by_puid(
            yt_client=self.yt_client,
            yql_client=self.yql_client,
            raw_sample_table=self.raw_sample_table,
            sample_by_puid_table=self.get_output_table_path(training_config.SAMPLE_BY_PUID),
            crypta_identifier_udf_url=self.crypta_identifier_udf_url,
            logger=self.logger,
        )
        log_results(self.logger, self.file_output, 'Matching stats:\n{}\n'.format(metrics.pandas_to_startrek(matching_stats)))
        log_results(self.logger, self.file_output, 'Classes stats:\n{}\n'.format(metrics.pandas_to_startrek(classes_stats)))

        return matching_stats, classes_stats

    def train_new_model(self, sample_by_puid):
        """
        Build sample_by_yuid from `sample_by_puid` and train a model on it.

        Returns:
            (model, metrics converted to a dict, top features).
        """
        sample_by_yuid = self.get_output_table_path(training_config.SAMPLE_BY_YUID)
        prepare_train_sample(
            yql_client=self.yql_client,
            sample_by_puid_path=sample_by_puid,
            sample_by_yuid_path=sample_by_yuid,
        )

        new_train_helper = AnyCustomClassificationModelTrainHelper(
            yt=self.yt_client,
            train_sample_path=sample_by_yuid,
            logger=self.logger,
            model_name='new_model',
        )
        new_model, new_metrics = new_train_helper.get_model_and_metrics()
        new_top_features = new_train_helper.get_top_features(
            model=new_model,
            number_of_features=self.number_of_top_features,
        )

        return new_model, metrics.convert_solomon_metrics_to_dict(new_metrics), new_top_features

    def save_model(self, new_model):
        """
        Save the model as `catboost_model.bin` in the output directory and return its YT path.
        """
        self.logger.info('Saving model in the output folder.')
        yt_model_path = os.path.join(self.output_dir, 'catboost_model.bin')
        utils.save_catboost_model(self.yt_client, new_model, yt_model_path)

        return yt_model_path

    def report_status(self, status):
        """
        Report the training status of this sample to the API (no-op when API reporting is off).
        """
        if self.send_results_to_api:
            self.logger.info(self.api.lab.reportTrainingStatus(
                sampleId=self.sample_id,
                status=status,
            ).result())

    def report_start(self):
        """
        Report the TRAINING state.
        """
        self.report_status(classification_pb2.ETrainingSampleState.Name(classification_pb2.ETrainingSampleState.TRAINING))

    def report_success(self):
        """
        Report the READY state.
        """
        self.report_status(classification_pb2.ETrainingSampleState.Name(classification_pb2.ETrainingSampleState.READY))

    @abstractmethod
    def run(self):
        """
        Method to train and validate model.
        """
        pass


class NewModelTrainValidateHelper(TrainAndValidateModelHelper):
    """
    Pipeline that trains a model on a new sample and, when segment validation
    is requested, applies it to profiles and builds the resulting segments.
    """

    def __init__(
        self,
        yt_client,
        yql_client,
        output_dir,
        raw_sample_table=None,
        raw_sample_file=None,
        audience_id=None,
        validate_segments=False,
        positive_output_segment_size=training_config.DEFAULT_OUTPUT_SEGMENT_SIZE,
        negative_output_segment_size=training_config.DEFAULT_OUTPUT_SEGMENT_SIZE,
        crypta_identifier_udf_url=training_config.IDENTIFIER_UDF_PATH,
        number_of_top_features=training_config.DEFAULT_NUMBER_OF_TOP_FEATURES,
        logger=None,
        send_results_to_api=False,
    ):
        """
        Args:
            positive_output_segment_size: forwarded to `get_modeled_segments`.
            negative_output_segment_size: forwarded to `get_modeled_segments`.

        All other arguments are passed unchanged to `TrainAndValidateModelHelper`.
        """
        # Keyword arguments keep the call correct even if the base signature is reordered.
        TrainAndValidateModelHelper.__init__(
            self,
            yt_client=yt_client,
            yql_client=yql_client,
            output_dir=output_dir,
            raw_sample_table=raw_sample_table,
            raw_sample_file=raw_sample_file,
            audience_id=audience_id,
            validate_segments=validate_segments,
            crypta_identifier_udf_url=crypta_identifier_udf_url,
            number_of_top_features=number_of_top_features,
            logger=logger,
            send_results_to_api=send_results_to_api,
        )

        self.positive_output_segment_size = positive_output_segment_size
        self.negative_output_segment_size = negative_output_segment_size

    def prepare_segments(self):
        """
        Build the initial segments from the train sample and the modeled
        segments from the predictions table.
        """
        get_initial_segments(
            yql_client=self.yql_client,
            train_sample_by_yuid_table=self.get_output_table_path(training_config.SAMPLE_BY_YUID),
        )
        get_modeled_segments(
            yql_client=self.yql_client,
            predictions_table_path=self.get_output_table_path(training_config.RESULTING_SEGMENTS),
            positive_output_segment_size=self.positive_output_segment_size,
            negative_output_segment_size=self.negative_output_segment_size,
        )

    def describe_segments(self):
        """
        Describe every initial/modeled x positive/negative segment in Siberia
        and register the resulting user set in the lab API.

        Uses `self.api`, so it must only be called when `send_results_to_api` is set.
        """
        tvm_ticket = utils.get_tvm_client().get_service_ticket_for('siberia')
        # Same fallback as for the API client in the base __init__: an unset
        # CRYPTA_ENVIRONMENT means 'testing' instead of a lookup by None.
        environment = os.getenv('CRYPTA_ENVIRONMENT') or 'testing'
        siberia_client = SiberiaClient(training_config.SIBERIA_HOST[environment], training_config.SIBERIA_PORT)

        # Upper bound on the rows read from each segment table.
        max_rows_to_describe = int(1e5)

        for origin_type in ('initial', 'modeled'):
            for target_type in ('positive', 'negative'):
                segment_table = os.path.join(self.output_dir, 'segments', '{}_{}'.format(origin_type, target_type))

                ids = IdProto.TIds(Ids=[
                    IdProto.TId(Type='yandexuid', Value=row['yandexuid'])
                    for row in self.yt_client.read_table(
                        self.yt_client.TablePath(segment_table, end_index=max_rows_to_describe),
                    )
                ])

                user_set_id = siberia_client.user_sets_describe_ids(
                    ids,
                    experiment=describing_experiment_pb2.TDescribingExperiment(CryptaIdUserDataVersion="by_crypta_id"),
                    tvm_ticket=tvm_ticket,
                ).UserSetId

                self.logger.info(self.api.lab.addSegmentDescription(
                    sampleId=self.sample_id,
                    originType=origin_type.upper(),
                    targetType=target_type.upper(),
                    userSetId=user_set_id,
                    rowsCount=int(self.yt_client.row_count(segment_table)),
                ).result())

    def process_training_metrics(
        self,
        matching_stats,
        new_metrics,
        new_top_features,
    ):
        """
        Save the training metrics, send them to the API and log their formatted form.
        """
        metrics_proto = metrics.save_metrics(
            yt_client=self.yt_client,
            output_dir=self.output_dir,
            matching_stats=matching_stats,
            model_metrics=new_metrics,
            model_top_features=new_top_features,
        )
        self.send_metrics(metrics_proto)

        formatted_metrics = metrics.format_metrics_comparison(
            new_metrics=new_metrics,
            new_top_features=new_top_features,
        )

        log_results(self.logger, self.file_output, formatted_metrics)

    def validate_resulting_segments(self, yt_model_path):
        """
        Apply the saved model to profiles, compute segment metrics for the
        predictions and write all accumulated metrics to the output directory.
        """
        self.logger.info('Computing table with segments predictions')
        apply_model_to_profiles(
            yt_client=self.yt_client,
            yql_client=self.yql_client,
            output_path=self.get_output_table_path(training_config.RESULTING_SEGMENTS),
            yt_model_path=yt_model_path,
        )

        formatted_metrics = metrics.compute_segments_metrics_for_new_model(
            yt_client=self.yt_client,
            yql_client=self.yql_client,
            predictions_table_path=self.get_output_table_path(training_config.RESULTING_SEGMENTS),
            train_sample_path=self.get_output_table_path(training_config.SAMPLE_BY_YUID),
        )

        log_results(self.logger, self.file_output, formatted_metrics)
        metrics.write_metrics(self.yt_client, self.output_dir, self.file_output)

    def run(self):
        """
        Train a model on the new sample and, if `validate_segments` is set,
        save it, validate the resulting segments, prepare them and (when API
        reporting is on) describe them.

        An AssertionError raised while preparing the sample is treated as a data
        requirements violation: it is logged and sent to the API, and the run stops.
        """
        self.report_start()

        try:
            matching_stats, _ = self.get_sample_by_puid()
        except AssertionError as error:
            # A bare `assert` carries no args; fall back to an empty message
            # instead of failing with IndexError while reporting the error.
            message = error.args[0] if error.args else ''

            self.logger.info('Data requirements are violated: {}'.format(message))
            self.send_error(message)

            return

        new_model, new_metrics, new_top_features = self.train_new_model(
            sample_by_puid=self.get_output_table_path(training_config.SAMPLE_BY_PUID),
        )
        self.process_training_metrics(
            matching_stats=matching_stats,
            new_metrics=new_metrics,
            new_top_features=new_top_features,
        )

        if not self.validate_segments:
            metrics.write_metrics(self.yt_client, self.output_dir, self.file_output)
            return

        yt_model_path = self.save_model(new_model)

        self.validate_resulting_segments(yt_model_path)

        self.prepare_segments()

        if self.send_results_to_api:
            self.describe_segments()

        self.report_success()


class ExistingModelTrainValidateHelper(TrainAndValidateModelHelper):
    """
    Pipeline that checks whether a new sample improves an existing industry
    model: trains the model on the stored sample and on the combined
    (stored + new) sample, compares them and optionally adds the new sample.
    """

    def __init__(
        self,
        yt_client,
        yql_client,
        industry_model_name,
        output_dir,
        partner=None,
        login=None,
        if_make_decision=False,
        raw_sample_table=None,
        raw_sample_file=None,
        audience_id=None,
        validate_segments=False,
        crypta_identifier_udf_url=training_config.IDENTIFIER_UDF_PATH,
        number_of_top_features=training_config.DEFAULT_NUMBER_OF_TOP_FEATURES,
        logger=None,
        send_results_to_api=False,
        logins_to_share=None,
    ):
        """
        Args:
            industry_model_name: name of the existing industry model; the
                industry directory is derived from it.
            partner: forwarded to `add_new_sample_for_existing_industry`.
            login: forwarded to `add_new_sample_for_existing_industry`.
            if_make_decision: when true, `run` calls
                `add_new_sample_for_existing_industry` after validation.
            logins_to_share: comma-separated logins that get grants for the
                model's segments; blank entries are ignored.

        All other arguments are passed unchanged to `TrainAndValidateModelHelper`.
        """
        # Keyword arguments keep the call correct even if the base signature is reordered.
        TrainAndValidateModelHelper.__init__(
            self,
            yt_client=yt_client,
            yql_client=yql_client,
            output_dir=output_dir,
            raw_sample_table=raw_sample_table,
            raw_sample_file=raw_sample_file,
            audience_id=audience_id,
            validate_segments=validate_segments,
            crypta_identifier_udf_url=crypta_identifier_udf_url,
            number_of_top_features=number_of_top_features,
            logger=logger,
            send_results_to_api=send_results_to_api,
        )
        self.industry_model_name = industry_model_name
        self.industry_yt_dir = utils.get_industry_dir_from_industry_name(self.industry_model_name)
        self.partner = partner
        self.login = login
        self.if_make_decision = if_make_decision
        # Blank entries (e.g. from "a,,b" or a trailing comma) are dropped so
        # that no grant is requested for an empty login.
        self.logins_to_share = [
            login_to_share.strip()
            for login_to_share in (logins_to_share or '').split(',')
            if login_to_share.strip()
        ]

    def get_industry_table_path(self, table_name):
        """
        Return the path of `table_name` inside the industry directory.
        """
        return os.path.join(self.industry_yt_dir, table_name)

    def train_existing_model(self):
        """
        Train a model on the sample stored in the industry directory.

        Returns:
            (metrics converted to a dict, top features).

        Raises:
            ValueError: if the industry directory does not exist in YT.
        """
        if not self.yt_client.exists(self.industry_yt_dir):
            raise ValueError('Invalid industry folder path. Expected to find on path: {}'.format(self.industry_yt_dir))

        train_helper = AnyCustomClassificationModelTrainHelper(
            yt=self.yt_client,
            train_sample_path=self.get_industry_table_path(training_config.SAMPLE_BY_YUID),
            logger=self.logger,
            model_name='existing_model',
        )
        existing_model, existing_metrics = train_helper.get_model_and_metrics()
        existing_top_features = train_helper.get_top_features(
            model=existing_model,
            number_of_features=self.number_of_top_features,
        )

        return metrics.convert_solomon_metrics_to_dict(existing_metrics), existing_top_features

    def process_training_metrics(
        self,
        existing_metrics,
        existing_top_features,
        new_metrics,
        new_top_features,
    ):
        """
        Log the formatted comparison of the existing and the new model metrics.
        """
        formatted_metrics = metrics.format_metrics_comparison(
            existing_metrics=existing_metrics,
            existing_top_features=existing_top_features,
            new_metrics=new_metrics,
            new_top_features=new_top_features,
        )

        log_results(self.logger, self.file_output, formatted_metrics)

    def make_sample_adding_decision(self, yt_model_path):
        """
        Compare the segments of the existing model with those of the new model
        and decide whether the new sample should be added.

        Returns:
            the `sample_should_be_added` value computed by
            `metrics.compute_segments_metrics_for_existing_model`.

        Raises:
            ValueError: if the existing predictions table does not exist in YT.
        """
        existing_predictions_table_path = os.path.join(
            training_config.EXISTING_MODEL_PREDICTIONS,
            '{}ModelApplication'.format(self.industry_model_name),
        )
        if not self.yt_client.exists(existing_predictions_table_path):
            raise ValueError('Invalid existing segments table path: {}'.format(existing_predictions_table_path))

        percentiles = get_percentiles_for_currently_computed_segments(
            yt_client=self.yt_client,
            yql_client=self.yql_client,
            segments_table=existing_predictions_table_path,
        )

        self.logger.info('Computing table with segments predictions')
        apply_model_to_profiles(
            yt_client=self.yt_client,
            yql_client=self.yql_client,
            output_path=self.get_output_table_path(training_config.RESULTING_SEGMENTS),
            yt_model_path=yt_model_path,
            percentiles=percentiles,
        )

        self.logger.info('Computing sample_by_yuid based on the new sample to validate segments on the new data')
        prepare_train_sample(
            yql_client=self.yql_client,
            sample_by_puid_path=self.get_output_table_path(training_config.SAMPLE_BY_PUID),
            sample_by_yuid_path=self.get_output_table_path(training_config.SAMPLE_BY_YUID_VALIDATION),
        )

        formatted_metrics, sample_should_be_added = metrics.compute_segments_metrics_for_existing_model(
            yt_client=self.yt_client,
            yql_client=self.yql_client,
            existing_predictions_table_path=existing_predictions_table_path,
            new_predictions_table_path=self.get_output_table_path(training_config.RESULTING_SEGMENTS),
            initial_train_sample_path=self.get_industry_table_path(training_config.SAMPLE_BY_YUID),
            new_train_sample_path=self.get_output_table_path(training_config.SAMPLE_BY_YUID_VALIDATION),
        )

        log_results(self.logger, self.file_output, formatted_metrics)
        metrics.write_metrics(self.yt_client, self.output_dir, self.file_output)

        return sample_should_be_added

    def share_segments(self):
        """
        Grant `logins_to_share` access to the segments whose export is tagged
        with the industry model name (no-op when API reporting is off or there
        is nobody to share with).
        """
        if not self.send_results_to_api or not self.logins_to_share:
            return

        segments = self.api.lab.getAllSegments().result()
        trainable_segments = set()
        for segment in segments:
            for export in segment.exports.exports:
                # NOTE(review): 557 is presumably the keyword id of trainable segments - confirm.
                if export.keywordId == 557 and self.industry_model_name in export.tags:
                    trainable_segments.add(str(export.segmentId))

        for segment in trainable_segments:
            self.logger.info(self.api.lab.createGrants(segmentId=segment, logins=self.logins_to_share).result())

    def run(self):
        """
        Train the existing model and a model on the combined sample, compare
        them and, if `validate_segments` is set, decide whether the new sample
        should be added, report the metrics of the chosen model, optionally add
        the sample and share the segments.

        An AssertionError raised while preparing the sample is treated as a data
        requirements violation: it is logged and sent to the API, and the run stops.
        """
        self.report_start()

        try:
            matching_stats, _ = self.get_sample_by_puid()
        except AssertionError as error:
            self.logger.info('Data requirements are violated: %s', error)
            self.send_error(str(error))
            return

        existing_metrics, existing_top_features = self.train_existing_model()

        merge_training_samples(
            yql_client=self.yql_client,
            storage_sample_by_puid_path=self.get_industry_table_path(training_config.SAMPLE_BY_PUID),
            new_sample_by_puid_path=self.get_output_table_path(training_config.SAMPLE_BY_PUID),
            combined_samples_path=self.get_output_table_path(training_config.COMBINED_SAMPLE_BY_PUID),
        )

        new_model, new_metrics, new_top_features = self.train_new_model(
            sample_by_puid=self.get_output_table_path(training_config.COMBINED_SAMPLE_BY_PUID),
        )

        self.process_training_metrics(
            existing_metrics=existing_metrics,
            existing_top_features=existing_top_features,
            new_metrics=new_metrics,
            new_top_features=new_top_features,
        )

        yt_model_path = self.save_model(new_model)

        if not self.validate_segments:
            metrics.write_metrics(self.yt_client, self.output_dir, self.file_output)
            return

        sample_should_be_added = self.make_sample_adding_decision(yt_model_path)
        # The reported metrics describe whichever model is kept.
        metrics_proto = metrics.save_metrics(
            yt_client=self.yt_client,
            output_dir=self.output_dir,
            matching_stats=matching_stats,
            model_metrics=new_metrics if sample_should_be_added else existing_metrics,
            model_top_features=new_top_features if sample_should_be_added else existing_top_features,
            model_type='new' if sample_should_be_added else 'existing',
        )
        self.send_metrics(metrics_proto)

        if self.if_make_decision:
            add_new_sample_for_existing_industry(
                yt_client=self.yt_client,
                yql_client=self.yql_client,
                industry_model_name=self.industry_model_name,
                new_sample_by_puid_table=self.get_output_table_path(training_config.SAMPLE_BY_PUID),
                partner=self.partner,
                login=self.login,
                logger=self.logger,
                add_to_training=sample_should_be_added,
                retrain_model=True,
                use_addition_date=True,
            )

        self.share_segments()
        self.report_success()
