from enum import Enum, unique
from textwrap import dedent
from datacloud.config.yt import PRODUCTION_ROOT
from datacloud.dev_utils.yql.yql_helpers import execute_yql, create_yql_client
from datacloud.dev_utils.yt import yt_utils
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.features.dssm.prepare_title_url import run_prepare_title_url
from datacloud.features.dssm.dot_product import run_dot_product
from datacloud.features.dssm.join_cids import join_cids
from datacloud.features.dssm.get_max_features import run_get_max_features
from datacloud.features.dssm.join_scores import join_dssm_scores
from datacloud.features.dssm.config import DSSMConfig, YT_PATH2MODEL


logger = get_basic_logger(__name__)


def clear_nan_map(rec):
    """Mapper that drops records whose ``vector`` field contains ``'nan'``.

    Yields the record unchanged when ``'nan'`` is not found in
    ``rec['vector']``; yields nothing otherwise.
    """
    if 'nan' in rec['vector']:
        return
    yield rec


@unique
class DSSMFeaturesBuildSteps(Enum):
    """Identifiers of the DSSM feature-building steps that can be selected.

    Each member's value is the same string as its name, and both match the
    name of the corresponding ``DSSMProcessor`` method.
    """
    # NOTE: DSSMProcessor also defines step_6_run_join_scores, which has no
    # member here.
    step_1_run_prepare_title_url = 'step_1_run_prepare_title_url'
    step_2_dssm_step = 'step_2_dssm_step'
    step_3_run_dot_product = 'step_3_run_dot_product'
    step_4_join_cids = 'step_4_join_cids'
    step_5_run_get_max_features = 'step_5_run_get_max_features'


# Default set of steps executed by build_retro_vectors: steps 1 through 5.
DSSM_DEFAULT_FEATURES_BUILD_STEPS = (
    DSSMFeaturesBuildSteps.step_1_run_prepare_title_url,
    DSSMFeaturesBuildSteps.step_2_dssm_step,
    DSSMFeaturesBuildSteps.step_3_run_dot_product,
    DSSMFeaturesBuildSteps.step_4_join_cids,
    DSSMFeaturesBuildSteps.step_5_run_get_max_features
)


class DSSMProcessor(object):
    """Runs the individual DSSM feature-building steps against YT/YQL.

    Holds the pipeline config together with the YT and YQL clients and
    exposes one method per step, plus folder creation and cleanup helpers.
    """

    def __init__(self, config, yt_client, yql_client=None):
        """Store the config and clients.

        When ``yql_client`` is falsy, one is created from ``yt_client`` via
        ``create_yql_client``.
        """
        self.config = config
        self.yt_client = yt_client
        self.yql_client = yql_client or create_yql_client(yt_client=self.yt_client)

    def _warn_if_cloud_nodes(self):
        """Log a warning when the config has ``use_cloud_nodes`` set."""
        if self.config.use_cloud_nodes:
            logger.warning('Attention! Using cloud nodes!')

    def step_1_run_prepare_title_url(self):
        """Call ``run_prepare_title_url`` with the config and the YT client."""
        self._warn_if_cloud_nodes()
        run_prepare_title_url(self.config, self.yt_client)

    def step_2_dssm_step(self):
        """Apply the DSSM model to titles/urls and store the vectors.

        The YQL query reads ``hash``, ``title`` and ``url`` from
        ``config.yuid2title_url4_table`` (nulls replaced by empty strings),
        applies the model attached as ``model.dssm`` and writes ``hash`` plus
        the space-joined ``doc_embedding_bigrams`` output as ``vector`` into
        ``config.vectors_table`` (truncating it). Rows whose vector contains
        ``'nan'`` are then filtered out in place, and the input table is
        removed once the result is confirmed non-empty.

        Raises:
            AssertionError: if ``yt_client.is_empty`` does not return
                ``False`` for the vectors table.
        """
        yql_query = dedent("""

            PRAGMA yt.DataSizePerJob="25M";

            $dssm_model = Dssm::LoadModel(FilePath("model.dssm"));

            $table = (
                SELECT
                    hash ?? "" as hash,
                    title ?? "" as title,
                    url ?? "" as url
                FROM `%(input_table)s`
            );

            INSERT INTO `%(output_table)s` WITH TRUNCATE
            select
                t.hash,
                String::JoinFromList(ListMap(Dssm::Apply(
                    $dssm_model,
                    AsStruct(title as doc_title, url as doc_url, "" as query, "" as doc_uta_url),
                    "doc_embedding_bigrams"
                ), ($x) -> {return CAST($x as String)}), ' ') AS vector
            FROM $table as t;
        """)
        execute_yql(query=yql_query, yql_client=self.yql_client, params=dict(
            input_table=self.config.yuid2title_url4_table,
            output_table=self.config.vectors_table
        ), urls2attach=[(self.config.model_url, 'model.dssm')], set_owners=False, syntax_version=1)

        # Hack to remove nans from table: the map reads from and writes to
        # the same table.
        self.yt_client.run_map(
            clear_nan_map,
            self.config.vectors_table,
            self.config.vectors_table,
        )

        # Raised explicitly instead of using an `assert` statement: asserts
        # are stripped under `python -O`, which would let the input table
        # below be removed even though the DSSM result is empty.
        # Strict identity check: anything other than an explicit False
        # counts as a failure.
        if self.yt_client.is_empty(self.config.vectors_table) is not False:
            raise AssertionError(
                'Error dssm result table {} is empty.'.format(self.config.vectors_table)
            )
        self.yt_client.remove(self.config.yuid2title_url4_table)

    def step_3_run_dot_product(self):
        """Call ``run_dot_product`` with the config and the YT client."""
        self._warn_if_cloud_nodes()
        run_dot_product(self.config, self.yt_client)

    def step_4_join_cids(self):
        """Call ``join_cids`` with the config and the YT client."""
        self._warn_if_cloud_nodes()
        join_cids(self.config, self.yt_client)

    def step_5_run_get_max_features(self):
        """Call ``run_get_max_features`` with the config and the YT client."""
        self._warn_if_cloud_nodes()
        run_get_max_features(self.config, self.yt_client)

    def step_6_run_join_scores(self, weekly_tables_to_take=25):  # 25 weekly tables, 175 days
        """Call ``join_dssm_scores`` for ``weekly_tables_to_take`` tables."""
        self._warn_if_cloud_nodes()
        # Unlike the other steps, the processor itself is passed rather than
        # the (config, yt_client) pair.
        join_dssm_scores(self, tables_to_take=weekly_tables_to_take)

    def collect_garbage(self):
        """Remove ``config.tmp_dir`` when garbage collection is enabled.

        Nothing is removed when ``config.garbage_collect_on`` is falsy or
        when the directory does not exist; each case is logged separately.
        """
        logger.info('Started garbage collect')
        if not self.config.garbage_collect_on:
            logger.info('Garbage collect turned off!')
            return
        if not self.yt_client.exists(self.config.tmp_dir):
            logger.info('Nothing to collect: {} does not exist'.format(self.config.tmp_dir))
            return
        # NOTE(review): no recursive flag is passed here; confirm the client
        # removes a non-empty directory this way.
        self.yt_client.remove(self.config.tmp_dir)
        logger.info('Garbage collected')

    def create_folders(self):
        """Create the tmp, weekly and ready directories from the config."""
        yt_utils.create_folders(
            [
                self.config.tmp_dir,
                self.config.weekly_dir,
                self.config.ready_dir
            ],
            self.yt_client
        )


class DSSMTables(object):
    """Facade that builds a ``DSSMConfig`` and a ``DSSMProcessor`` from it.

    Every ``step_*`` method, ``create_folders`` and ``collect_garbage``
    simply delegates to the processor; ``get_grep_tables`` and
    ``cloud_nodes_spec`` delegate to the config.
    """

    def __init__(self, date_str, base_root=PRODUCTION_ROOT, days_to_take=None,
                 test=False, is_retro=False, yt_client=None, yql_client=None,
                 model_url=YT_PATH2MODEL, retro_tag='', garbage_collect_on=True,
                 use_cloud_nodes=False, yuid2cid_path='',
                 weekly_dir='', ready_table_path=''):
        """Create the config and the processor.

        Use ``retro_tag`` to tell your tmp folder apart from others.

        All arguments except ``test``, ``yt_client`` and ``yql_client`` are
        forwarded to ``DSSMConfig``; the two clients go to ``DSSMProcessor``.
        ``test`` is accepted but not used in this class.
        """
        self.config = DSSMConfig(
            date_str, base_root, days_to_take,
            is_retro,
            model_url, retro_tag, garbage_collect_on, use_cloud_nodes,
            yuid2cid_path=yuid2cid_path,
            weekly_dir=weekly_dir,
            ready_table_path=ready_table_path
        )
        self.processor = DSSMProcessor(self.config, yt_client, yql_client)
        # Shortcuts to values held by the config.
        self.log_folders = self.config.log_folders
        self.grep_root = self.config.grep_root

    @property
    def cloud_nodes_spec(self):
        """Result of calling ``config.cloud_nodes_spec()``."""
        return self.config.cloud_nodes_spec()

    def create_folders(self):
        """Delegate to ``processor.create_folders``."""
        self.processor.create_folders()

    def step_1_run_prepare_title_url(self):
        """Delegate to ``processor.step_1_run_prepare_title_url``."""
        self.processor.step_1_run_prepare_title_url()

    def step_2_dssm_step(self):
        """Delegate to ``processor.step_2_dssm_step``."""
        self.processor.step_2_dssm_step()

    def step_3_run_dot_product(self):
        """Delegate to ``processor.step_3_run_dot_product``."""
        self.processor.step_3_run_dot_product()

    def step_4_join_cids(self):
        """Delegate to ``processor.step_4_join_cids``."""
        self.processor.step_4_join_cids()

    def step_5_run_get_max_features(self):
        """Delegate to ``processor.step_5_run_get_max_features``."""
        self.processor.step_5_run_get_max_features()

    def step_6_run_join_scores(self, tables_to_take=25):  # 25 weekly tables, 175 days
        """Delegate to ``processor.step_6_run_join_scores``."""
        self.processor.step_6_run_join_scores(tables_to_take)

    def collect_garbage(self):
        """Delegate to ``processor.collect_garbage``."""
        self.processor.collect_garbage()

    def get_grep_tables(self, yt_client, config):
        """Return ``self.config.get_grep_tables(yt_client, config)``."""
        return self.config.get_grep_tables(yt_client, config)


def build_retro_vectors(processor, steps_to_run=DSSM_DEFAULT_FEATURES_BUILD_STEPS):
    """Run the selected DSSM build steps in order, then collect garbage.

    Folders are created first. Steps 1-5 are then visited in their fixed
    order; a step is executed only when its ``DSSMFeaturesBuildSteps`` member
    is in ``steps_to_run``. Each step logs either "done" (after it ran) or
    "skipped", so the log reflects what actually happened.

    Args:
        processor: object exposing ``create_folders``, the ``step_*``
            methods named after the enum members, and ``collect_garbage``.
        steps_to_run: container of ``DSSMFeaturesBuildSteps`` members.
    """
    processor.create_folders()

    ordered_steps = (
        DSSMFeaturesBuildSteps.step_1_run_prepare_title_url,
        DSSMFeaturesBuildSteps.step_2_dssm_step,
        DSSMFeaturesBuildSteps.step_3_run_dot_product,
        DSSMFeaturesBuildSteps.step_4_join_cids,
        DSSMFeaturesBuildSteps.step_5_run_get_max_features,
    )
    for number, step in enumerate(ordered_steps, 1):
        if step in steps_to_run:
            # Each enum member is named after the processor method it selects;
            # the method is looked up only when the step actually runs.
            getattr(processor, step.name)()
            logger.info('\n=== Step {} done ===\n'.format(number))
        else:
            logger.info('\n=== Step {} skipped ===\n'.format(number))

    processor.collect_garbage()
