# coding=utf-8

import datetime
import logging

from yt.wrapper import YtHttpResponseError
import yt.wrapper as yt
from irt.bannerland.options import get_option as get_bl_opt
import irt.broadmatching.common_options
from bm.yt_tools import get_yt_bm_config, get_cdict_generation_params
from bm.bmyt import BMYT


from generate_norms import generate_norms
from generate_counts import generate_counts, merge_counts
from generate_counts_query import generate_counts_query

from generate_bnr_counts import generate_bnr_count
from generate_snorms import generate_snorms_and_syns
from generate_categs import generate_categs
from generate_tails import generate_tails


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Base YT path for cdict generation (the lock table is created under it).
# Read from the cdict generation params once, when the module is imported.
CDICT_GENERATION_PATH = get_cdict_generation_params()['cdict_generations_yt_path']


class GetKeyReducer:
    """Reducer that emits the reduce key once for every non-empty group.

    The contents of the grouped records are ignored; only the key is
    written out, so the output holds one row per distinct key.
    """

    def __init__(self):
        pass

    def __call__(self, key, recs):
        """Yield ``key`` a single time if ``recs`` has at least one record."""
        nothing = object()
        # Pull at most one record from the group, exactly enough to learn
        # whether the group is empty.
        if next(iter(recs), nothing) is not nothing:
            yield key


# Generates all input tables for all cdicts
class PrepareCdictInput:
    def __init__(self):
        """Load the generation options and set up the YT clients.

        Reads the query-log statistics params and the Direct banners /
        campaigns options, copies the BM YT credentials and proxy URL into
        the ``yt.wrapper`` module config, creates a ``BMYT`` helper and
        remembers the cdict generation path and its lock table path.
        """
        options = irt.broadmatching.common_options.get_options()['QueryLogStat_params']
        # NOTE(review): the options are indexed by 'QueryLogStat_params' twice;
        # confirm that this nesting is intended.
        self.QueryLogStat_params = options['QueryLogStat_params']
        self.yt_direct_banners = get_bl_opt('yt_direct_banners')
        self.yt_direct_campaigns = get_bl_opt('yt_direct_campaigns')

        self.languages = self.QueryLogStat_params['languages']

        # Push credentials (only those present) and the proxy URL into the
        # yt.wrapper module-level config.
        bm_yt_config = get_yt_bm_config()
        for credential_key in ('token_path', 'token'):
            if credential_key in bm_yt_config:
                yt.config[credential_key] = bm_yt_config[credential_key]
        yt.config['proxy']['url'] = bm_yt_config["proxy"]["url"]

        bmyt_client = BMYT(process_count=4)
        self.bmyt_cl = bmyt_client
        self.yt_client = bmyt_client.yt_client

        generation_dir = CDICT_GENERATION_PATH
        self.cdict_gen_dir = generation_dir
        self.lock_table_path = generation_dir + "/_lock_gen_cdict"

    def finish_yt_regeneration(self):
        now_dt_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        self.yt_client.set_attribute(self.lock_table_path, 'last_generation_date', now_dt_str)

    def try_get_yt_lock(self):
        """Try to take an exclusive YT lock on the lock table.

        Returns:
            True if the lock request succeeded; False if it raised
            ``YtHttpResponseError`` (the error is logged, not re-raised).
        """
        try:
            self.yt_client.lock(self.lock_table_path, mode='exclusive')
        except YtHttpResponseError as err:
            # Keep the server's error in the log: any HTTP error lands here,
            # so without it a held lock cannot be told apart from other failures.
            logger.error('get lock fail, generation skipped: %s', err)
            return False
        return True

    def get_last_generation(self):
        str_last_gen_time = self.yt_client.get_attribute(self.lock_table_path, "last_generation_date")
        return datetime.datetime.strptime(str_last_gen_time, '%Y-%m-%d %H:%M:%S')

    def is_old(self):
        """Tell whether the last generation is older than the configured minimum age.

        Creates the lock table if it does not exist yet (with a 2018 default
        ``last_generation_date``), then compares the time elapsed since the
        stored date with ``min_time_generation_hour`` from the cdict
        generation params.

        Returns:
            True if strictly more than the minimum interval has passed;
            False otherwise (an info message is logged in that case).
        """
        min_age = datetime.timedelta(
            hours=get_cdict_generation_params()["min_time_generation_hour"]
        )

        # Make sure the lock table exists; a freshly created table gets a
        # default generation date in the past.
        self.yt_client.create(
            'table',
            self.lock_table_path,
            ignore_existing=True,
            attributes={"last_generation_date": "2018-01-01 01:01:01"},
        )

        previous_run = self.get_last_generation()
        elapsed = datetime.datetime.now() - previous_run

        outdated = elapsed > min_age
        if not outdated:
            logger.info("Last generation is up-to-time (%s), generation skipped", previous_run)
        return outdated

    def generate_counts_data_yt(self):
        """Build norms, counts and query-counts tables per language, then merge full counts.

        For every language in ``self.languages`` this runs, in order,
        ``generate_norms``, ``generate_counts`` and ``generate_counts_query``
        with the table paths from that language's params. Afterwards the
        per-language ``yt_table_counts_full`` tables are merged into the
        common ``yt_table_counts_full`` table via ``merge_counts``.
        """
        for lang in self.languages:
            params = self.QueryLogStat_params['languages'][lang]
            logger.info('generate_counts_data_yt for %s', lang)

            generate_norms(self.yt_client,
                           params['yt_advq_phits_dir'],

                           params['yt_table_norms'],
                           params['yt_table_norms_full'],
                           dest_table_orig=params['yt_table_norms_orig'])

            generate_counts(self.yt_client,
                            params['yt_table_norms_full'],
                            params['yt_table_counts_full'],
                            params['yt_table_counts'],
                            dst_table_geo=params['yt_table_counts_geo']
                            )

            generate_counts_query(self.yt_client,
                                  params['yt_table_norms'],
                                  params['yt_table_counts_query'])
            logger.info('/ generate_counts_data_yt for %s', lang)

        # Build a real list: under Python 3 ``map`` returns a one-shot lazy
        # iterator, while a list of table paths behaves the same everywhere.
        counts_full_tables = [
            self.QueryLogStat_params['languages'][lang_name]['yt_table_counts_full']
            for lang_name in self.languages
        ]
        merge_counts(self.yt_client,
                     counts_full_tables,
                     self.QueryLogStat_params['yt_table_counts_full']
                     )

    def generate_counts_mob_data_yt(self):
        """Build the mobile norms and counts tables for every language.

        Per language, runs ``generate_norms`` with ``is_mobile=True`` and then
        ``generate_counts`` on the resulting ``*_mob`` tables.
        """
        for lang in self.languages:
            lang_cfg = self.QueryLogStat_params['languages'][lang]
            logger.info('generate_counts_mob_data_yt for %s', lang)

            # Source dir followed by the two destination norms tables.
            norms_args = (
                lang_cfg['yt_advq_phits_dir'],
                lang_cfg['yt_table_norms_mob'],
                lang_cfg['yt_table_norms_full_mob'],
            )
            generate_norms(self.yt_client, *norms_args, is_mobile=True)

            # Full norms table in, full and regular counts tables out.
            counts_args = (
                lang_cfg['yt_table_norms_full_mob'],
                lang_cfg['yt_table_counts_full_mob'],
                lang_cfg['yt_table_counts_mob'],
            )
            generate_counts(self.yt_client, *counts_args)

            logger.info('/ generate_counts_mob_data_yt for %s', lang)

    def generate_bnr_counts(self):
        """Run ``generate_bnr_count`` over the Direct banners and campaigns tables.

        The destinations are the common ``yt_table_bnr_counts`` table and the
        ``yt_table_bnr_counts`` tables of the 'ru' and 'tr' languages.
        """
        logger.info('generate_bnr_counts')

        stat_params = self.QueryLogStat_params
        destinations = (
            stat_params['yt_table_bnr_counts'],
            stat_params['languages']['ru']['yt_table_bnr_counts'],
            stat_params['languages']['tr']['yt_table_bnr_counts'],
        )
        generate_bnr_count(
            self.bmyt_cl,
            self.yt_direct_banners,
            self.yt_direct_campaigns,
            *destinations
        )
        logger.info('/ generate_bnr_counts')

    def generate_flags_yt(self):
        """Build the per-language phrase flags tables.

        For every language, a BM map (with embedded Perl code) runs over
        ``yt_table_counts`` and writes ``phrase``/``flags`` rows to a temporary
        table for phrases whose flags are non-zero. A map-reduce with no mapper
        and ``GetKeyReducer`` then reduces by (``phrase``, ``flags``) into
        ``yt_table_flags``, leaving one row per distinct pair.
        """
        for lang in self.languages:
            params = self.QueryLogStat_params['languages'][lang]
            logger.info('generating generate_flags_yt for %s', lang)
            # The intermediate table only lives for the duration of this block.
            with self.yt_client.TempTable() as tmp_table:
                bm_mapper = {
                    # Setup code: the language name is spliced into the Perl
                    # source by string concatenation; the project object and
                    # the selected language are stored on $self for the mapper.
                    'begin': """
                        use BaseProject;
                        my $lang = '""" + lang + """' // '';

                        my $proj = BaseProject->new({
                            load_dicts => 1,
                            allow_lazy_dicts => 1,
                            load_languages => [ qw(ru en tr) ],
                        });
                        $self->{proj} = $proj;
                        $self->{language} = $lang ? $proj->get_language($lang) : $proj->default_language;
                    """,
                    # Per-row code: OR together the wide-phrase flag (phrase is
                    # wide) and the bad-phrase flag (phrase is not good); rows
                    # with zero flags are not emitted.
                    'mapper': '''
                        my $text = $r->{'norm'};
                        my $phr = $self->{language}->phrase($text);
                        my $flags =
                            ($phr->is_wide_phrase ? $self->{proj}->cdict_client->{wide_phrase_flag} : 0) |
                            ($phr->is_good_phrase ? 0 : $self->{proj}->cdict_client->{bad_phrase_flag});

                        my $out_r;
                        if ($flags) {
                            $out_r->{'phrase'} = $text;
                            $out_r->{'flags'} = $flags;

                            yield($out_r => YT_TABLE_FLAGS);
                        }
                    ''',
                    'dst_names': ['YT_TABLE_FLAGS'],
                    # NOTE(review): 'flags' is declared as str although the
                    # mapper computes it with a bitwise OR -- confirm intended.
                    'dst_fields': [{"phrase": str,
                                   "flags": str, }]
                }

                # Step 1: compute flags for every row of the counts table.
                self.bmyt_cl.run_bm_map(
                    bm_mapper,
                    params['yt_table_counts'],
                    tmp_table
                )

                # Step 2: no mapper (None); the reducer emits each
                # (phrase, flags) key once.
                self.bmyt_cl.yt_client.run_map_reduce(
                    None,
                    GetKeyReducer(),
                    tmp_table,
                    params['yt_table_flags'],
                    reduce_by=['phrase', 'flags']
                )

            logger.info('/ generating generate_flags_yt for %s', lang)

    def generate_snorms_yt(self):
        """Run ``generate_snorms_and_syns`` for every language.

        The counts table is passed first, followed by the snorms and syns
        tables and the language name.
        """
        for lang in self.languages:
            lang_cfg = self.QueryLogStat_params['languages'][lang]
            logger.info('generate_snorms_yt for %s', lang)
            counts_table = lang_cfg['yt_table_counts']
            snorms_table = lang_cfg['yt_table_snorms']
            syns_table = lang_cfg['yt_table_syns']
            generate_snorms_and_syns(self.bmyt_cl, counts_table, snorms_table, syns_table, lang)
            logger.info('/ generate_snorms_yt for %s', lang)

    def generate_categs_yt(self):
        """Run ``generate_categs`` for every language.

        Passes the syns table followed by the categs, categs-regions,
        syns-categs and regions tables from the language's params.
        """
        table_keys = (
            'yt_table_syns',
            'yt_table_categs',
            'yt_table_categs_regions',
            'yt_table_syns_categs',
            'yt_table_regions',
        )
        for lang in self.languages:
            lang_cfg = self.QueryLogStat_params['languages'][lang]
            logger.info('generate_categs_yt data for %s', lang)
            # Paths are looked up in the same order the callee receives them.
            tables = [lang_cfg[table_key] for table_key in table_keys]
            generate_categs(self.bmyt_cl, *tables, lang=lang)
            logger.info('/ generate_categs_yt data for %s', lang)

    def generate_tails_yt(self):
        """Run ``generate_tails`` from the syns-categs table into the tails table, per language."""
        for lang in self.languages:
            lang_cfg = self.QueryLogStat_params['languages'][lang]
            logger.info('generate_tails_yt for %s', lang)
            client = self.bmyt_cl.yt_client
            generate_tails(
                client,
                lang_cfg['yt_table_syns_categs'],
                lang_cfg['yt_table_tails'],
            )
            logger.info('/ generate_tails_yt for %s', lang)

    def generate_harm_yt(self):
        """Fill the per-language harm tables with a YQL query.

        Inside a single YT transaction, for every language the rows of
        ``yt_table_norms_orig`` with ``freq_query >= 10`` are written (with
        truncation) to ``yt_table_harm`` as ``norm`` / ``orig`` columns.
        """
        with self.bmyt_cl.yt_client.Transaction() as harm_tx:
            for lang in self.languages:
                logger.info('generate_harm_yt for %s', lang)

                lang_cfg = self.QueryLogStat_params['languages'][lang]
                source_table = lang_cfg['yt_table_norms_orig']
                target_table = lang_cfg['yt_table_harm']

                # The query text (including its whitespace) is sent as is.
                query_template = '''
                PRAGMA yt.ForceInferSchema;
                use hahn;

                INSERT INTO `{out_table}` WITH TRUNCATE
                SELECT
                    norm,
                    orig_text as orig
                FROM
                    `{in_table}`
                WHERE
                    freq_query>=10;
                '''
                harm_query = query_template.format(in_table=source_table, out_table=target_table)

                # Every language's query runs under the same transaction.
                self.bmyt_cl.do_yql(
                    harm_query,
                    title='generate harm for ' + lang,
                    transaction_id=harm_tx.transaction_id,
                )

                logger.info('/ generate_harm_yt for %s', lang)

    def generate_cdict_input_full(self):
        """Run every generation step of this class, one after another.

        Order: counts, mobile counts, banner counts, flags, snorms, categs,
        tails, harm. An exception in any step stops the sequence.
        """
        logger.info('generate_cdict_input_full')
        pipeline = (
            self.generate_counts_data_yt,
            self.generate_counts_mob_data_yt,
            self.generate_bnr_counts,
            self.generate_flags_yt,
            self.generate_snorms_yt,
            self.generate_categs_yt,
            self.generate_tails_yt,
            self.generate_harm_yt,
        )
        for run_step in pipeline:
            run_step()
        logger.info('/ generate_cdict_input_full')
