#!/usr/bin/env python
# -*- coding: utf-8 -*-

from cached_property import cached_property

from crypta.lib.python.custom_ml.classification import (
    CustomClassificationModelTrainHelper,
    CustomClassificationParams,
)
from crypta.lib.python.custom_ml.classification.model_training import make_common_classification_train_sample_query
from crypta.lib.python.custom_ml.tools.training_utils import get_model_tag
from crypta.profile.runners.segments.lib.coded_segments.ml_tools.utils import resources
from crypta.profile.utils.utils import revert_dict


class LegalEntitiesModelTrainHelper(CustomClassificationModelTrainHelper):
    """Training helper for the ``legal_entities`` custom classification model.

    Specialises :class:`CustomClassificationModelTrainHelper` with the
    resource type, sample tables and train-sample query for this model.
    """

    def __init__(self, yt, logger, date, storage_sample_path, train_sample_path):
        """Create the helper.

        :param yt: forwarded unchanged to the base helper.
        :param logger: forwarded unchanged to the base helper.
        :param date: forwarded unchanged to the base helper.
        :param storage_sample_path: table used as ``input_table`` of the
            train-sample query.
        :param train_sample_path: table used as ``output_table`` of the
            train-sample query; also passed as ``train_sample_path`` in
            ``model_params``.
        """
        super(LegalEntitiesModelTrainHelper, self).__init__(yt=yt, logger=logger, date=date)
        # Key used both for the ``resources`` lookup and for the metrics group name.
        self.model_name = 'legal_entities'
        self.storage_sample_path = storage_sample_path
        self.train_sample_path = train_sample_path

    @cached_property
    def model_params(self):
        """Return the :class:`CustomClassificationParams` for this model.

        Built on first access and cached on the instance afterwards.
        """
        resource_type = resources[self.model_name]
        model_tag = get_model_tag(resource_type)
        metrics_group = self.get_metrics_group_name(self.model_name)

        # NOTE(review): segment_id_to_name is not assigned in this class;
        # presumably provided by the base helper — confirm.
        id_to_name = self.segment_id_to_name
        name_to_id = revert_dict(id_to_name)

        # Fill the shared query template; ``additional_conditions`` is left empty.
        sample_query = make_common_classification_train_sample_query.format(
            input_table=self.storage_sample_path,
            output_table=self.train_sample_path,
            additional_conditions='',
        )

        return CustomClassificationParams(
            train_sample_path=self.train_sample_path,
            resource_type=resource_type,
            model_tag=model_tag,
            metrics_group_name=metrics_group,
            model_description_in_sandbox='Model for legal entities segments classification',
            segment_id_to_name=id_to_name,
            segment_name_to_id=name_to_id,
            ordered_thresholds=None,
            make_train_sample_query=sample_query,
        )
