import codecs
import datetime
import json
import logging
import os
import sys
import subprocess
import sqlite3
import multiprocessing
import time
from collections import Counter

import yt as yt_main
from yt.wrapper.common import chunk_iter_stream
import yt.wrapper as yt
from mapreducelib import MapReduce
from yql.api.v1.client import YqlClient

from beast_sample import BeastSample
from settings import SerpSettings, GoogleSerpSettingsTouch
from parse_serp_mr_jobs import SerpParserCombiner
from SerpParser import SerpParser, SkeletonEvaluator
from transliterator import Transliterator


class AnatomyRazladki:
    """Weekly SERP-anatomy job: measure how often labelled SERP features occur
    on sampled SERPs and push the numbers to Razladki.

    ``main`` first builds the ``skeleton_feature_nocollision`` table
    (``eval_tag_skeletons``), then, for every week of the interval,
    ``eval_razladki_for_week`` samples SERP HTML, parses it into element
    skeletons, joins skeleton md5s with labelled features, weights queries by
    their TOUCH-platform frequency in nano_sessions and sends per-feature
    counts/coverage through ``send2razladki.py``.

    Every intermediate YT table / local file is reused when it already exists
    unless ``do_recalculate`` is set, so an interrupted run can be restarted.
    """
    # YT directory for all intermediate tables (trailing slash: names are appended).
    MR_DIR = '//home/search-functionality/serp_anatomy/razladki/'
    # YT directory with ``backup_large_*`` archives (competitors.db + saved SERP HTML).
    BACKUP_DIR = '//home/search-functionality/serp_anatomy'
    # Local TSV: skeleton_md5, feature_id, feature_name, serp_url_static, seanid.
    SKELETON_FEATURES_FNAME = 'backup_dir/skeleton_features.tsv'
    DATE_INPUT_FORMAT = '%Y%m%d'
    SERP_SETTINGS = GoogleSerpSettingsTouch()
    # Prefix of every metric name sent to Razladki.
    RAZLADKI_PREFIX = 'goog'
    NANO_SESSIONS_TABLE_PATTERN = '//user_sessions/pub/nano_sessions/daily/{date_yt}/web/clean'
    # NOTE(review): not referenced in this file; weeks are processed sequentially.
    PARALLEL_PROCESSES = 5
    DO_RECALCULATE = False
    # Prefer a copy of the script in the working directory, else the Arcadia checkout path.
    SEND2RAZLADKI_PATH = 'send2razladki.py' if os.path.exists('send2razladki.py') else '../../yweb/blender/scripts/tools/send2razladki.py'

    def __init__(self):
        # Per-instance copy so a caller can force recomputation without touching the class default.
        self.do_recalculate = self.DO_RECALCULATE

    def _needs_rebuild(self, table):
        """Return True when YT ``table`` must be (re)computed: recalculation is
        forced or the table does not exist yet."""
        return self.do_recalculate or not yt.exists(table)

    @staticmethod
    def _render_sql(template, **substitutions):
        """Strip ``template`` and replace every ``{name}`` with ``substitutions[name]``.

        Plain ``str.replace`` is used (not ``str.format``) so the SQL text may
        contain other braces untouched.
        """
        sql = template.strip()
        for name, value in substitutions.items():
            sql = sql.replace('{' + name + '}', value)
        return sql

    @classmethod
    def _str2date(cls, s):
        """Parse a 'YYYYMMDD' string into a naive ``datetime``."""
        return datetime.datetime.strptime(s, cls.DATE_INPUT_FORMAT)

    @classmethod
    def _date2str(cls, d):
        """Format a date/datetime as 'YYYYMMDD'."""
        return d.strftime(cls.DATE_INPUT_FORMAT)

    @classmethod
    def _get_dates_from_interval(cls, start, end):
        """Yield every date from ``start`` to ``end`` inclusive as 'YYYYMMDD'.

        :raises ValueError: if ``start`` is later than ``end`` (raised on the
            first iteration, since this is a generator).
        """
        s = cls._str2date(start)
        e = cls._str2date(end)
        if s > e:
            raise ValueError('start date should not be later than end date: %s > %s' % (start, end))
        one_day = datetime.timedelta(days=1)
        while s <= e:
            yield cls._date2str(s)
            s += one_day

    @classmethod
    def _get_weeks_from_interval(cls, start, end):
        """Split [start, end] into lists of 'YYYYMMDD' strings.

        A list is closed after a Sunday (``isoweekday() == 7``) or once it has
        7 dates; the last list ends on ``end`` and may be shorter.
        """
        week = []
        for date in cls._get_dates_from_interval(start, end):
            week.append(date)
            if len(week) == 7 or cls._str2date(date).isoweekday() == 7:
                yield week
                week = []
        if week:
            yield week

    def eval_serp_element_skeletons(self, html_table, raw_skel_table):
        """Run ``SerpParserCombiner`` over ``html_table`` via MapReduce.

        The four destination tables are dropped first and then rewritten.
        ``raw_skel_table`` is the one consumed downstream (sorted and reduced
        by its ``value`` column). The ``_parsed``/``_skel``/``_serponly``
        tables are additional combiner outputs; their content is defined by
        ``SerpParserCombiner`` (not visible here).
        """
        serp_parser_combiner = SerpParserCombiner(self.SERP_SETTINGS, raise_exceptions=False, verbose=False,
                                                  fields_to_out=(SerpParserCombiner.FIELD_RAW_SKELETON,))

        dstTables = [
            raw_skel_table + "_parsed",
            raw_skel_table + "_skel",
            raw_skel_table + "_serponly",
            raw_skel_table
        ]
        MapReduce.dropTables(dstTables)
        MapReduce.runCombine(
            serp_parser_combiner,
            srcTable=html_table,
            dstTables=dstTables,
            auxExecArguments=['-ytspec', '{"data_size_per_job": 16777216, "title": "SerpParserCombiner anatomy_razladki"}'],
        )

    def skeleton_uniq_query_reducer(self, key, recs):
        """YT reducer: emit each distinct (query, skeleton_md5) pair once.

        Rows are reduced by ``value`` (the query); ``key`` of each row holds
        the skeleton md5.
        """
        seen_md5 = set()
        query = key['value']
        for rec in recs:
            skeleton_md5 = rec['key']
            if skeleton_md5 in seen_md5:
                continue
            seen_md5.add(skeleton_md5)
            yield dict(query=query, skeleton_md5=skeleton_md5)

    class EUIType:
        """Mapping of the integer ``ui`` field of nano_sessions to a platform name."""
        # uitype in nano_sessions
        # https://arc.yandex-team.ru/wsvn/arc/trunk/arcadia/quality/user_sessions/amon/request.h#l24
        _ARR = [
            'DESKTOP',
            'PAD',
            'MOBILE',
            'TOUCH',
            'MOBILE_APP'
        ]

        @classmethod
        def int_code2str(cls, int_code):
            """Return the platform name for ``int_code`` (IndexError if unknown)."""
            return cls._ARR[int_code]

    def nano_sessions_extract_query_mapper(self, rec):
        """YT mapper: yield ``{'query': corrected_query}`` for TOUCH sessions only.

        ``rec['value']`` is a JSON object with an integer ``ui`` field and,
        for TOUCH sessions, a ``corrected_query`` field.
        """
        nano_session = json.loads(rec['value'])
        assert isinstance(nano_session, dict)
        # Filter on platform first so sessions of other platforms need not carry corrected_query.
        if self.EUIType.int_code2str(nano_session['ui']) != 'TOUCH':
            return
        yield dict(query=nano_session['corrected_query'])

    def key_freq_reducer(self, key, recs):
        """YT reducer/combiner: sum ``count`` per key (rows without it count as 1).

        Usable as both combiner and reducer, because its output rows carry
        ``count``.
        """
        ret = dict(key)
        ret['count'] = sum(rec.get('count', 1) for rec in recs)
        yield ret

    @staticmethod
    def _generalized_feature_names(feature_name):
        """Yield the stripped, non-empty prefixes of ``feature_name`` that end
        right before each ':' (a ':' at position 0 is ignored).

        E.g. u'A: B: C' yields u'A' and u'A: B'.
        """
        pos = feature_name.find(':', 1)
        while pos >= 0:
            prefix = feature_name[:pos].strip()
            if prefix:
                yield prefix
            pos = feature_name.find(':', pos + 1)

    def get_razladki_key_values(self, feature_freq_table, query_count_table):
        """Turn aggregated feature counts into sorted (metric_name, value) pairs.

        Each feature also contributes to every ':'-separated prefix of its name
        (see ``_generalized_feature_names``). For every resulting name four
        metrics are produced, named ``RAZLADKI_PREFIX + '_' + <transliterated
        name> + suffix``:
        ``_beast_query_count``, ``_yandex_query_count``, ``_beast_coverage`` and
        ``_yandex_coverage`` (the counts divided by the weekly totals).

        :param feature_freq_table: YT table with feature_name (utf-8 bytes),
            distinct_beast_query_count and yandex_query_count columns.
        :param query_count_table: YT table with beast_query_count and
            yandex_query_count totals (a single aggregate row).
        :return: list of (metric_name, number) tuples sorted by metric name.
        """
        sum_beast_query_count = 0
        sum_yandex_query_count = 0
        # The table is produced by an aggregate without GROUP BY, i.e. one row.
        for rec in yt.read_table(query_count_table):
            sum_beast_query_count = rec['beast_query_count']
            sum_yandex_query_count = rec['yandex_query_count']

        beast_counts = Counter()
        yandex_counts = Counter()
        for rec in yt.read_table(feature_freq_table):
            feature_name = rec['feature_name'].decode('utf8')
            names = [feature_name]
            names.extend(self._generalized_feature_names(feature_name))
            for name in names:
                beast_counts[name] += rec['distinct_beast_query_count']
                yandex_counts[name] += rec['yandex_query_count']

        transliterator = Transliterator()
        razladki_key_values = []
        for feature_name in yandex_counts:
            metric_prefix = self.RAZLADKI_PREFIX + '_' + transliterator.transliterate(feature_name)
            beast_count = beast_counts[feature_name]
            yandex_count = yandex_counts[feature_name]
            razladki_key_values.extend([
                (metric_prefix + '_beast_query_count', beast_count),
                (metric_prefix + '_yandex_query_count', yandex_count),
                # The float epsilon avoids division by zero and forces true
                # division under Python 2.
                (metric_prefix + '_beast_coverage', beast_count / (sum_beast_query_count + 1E-12)),
                (metric_prefix + '_yandex_coverage', yandex_count / (sum_yandex_query_count + 1E-12)),
            ])
        razladki_key_values.sort()
        return razladki_key_values

    def send2razladki(self, date, razladki_key_values):
        """Send one tab-separated line of metrics to Razladki via ``SEND2RAZLADKI_PATH``.

        Line format: unix timestamp of ``date`` (local midnight), the project
        name 'serp_anatomy', then one 'key value' field per metric, sorted.

        :param date: 'YYYYMMDD' string.
        :param razladki_key_values: iterable of (metric_name, value); not modified.
        :raises subprocess.CalledProcessError: if the script exits non-zero.
        """
        project_name = 'serp_anatomy'
        # time.mktime interprets the naive datetime as local time.
        timestamp = str(int(time.mktime(self._str2date(date).timetuple())))
        # sorted() rather than list.sort(): don't mutate the caller's list.
        metrics = [key + ' ' + str(val) for key, val in sorted(razladki_key_values)]
        send2razladki_line = '\t'.join([timestamp, project_name] + metrics)

        logging.info('call python ' + self.SEND2RAZLADKI_PATH)
        # Argument list instead of shell=True: the path is passed verbatim, never shell-parsed.
        process = subprocess.Popen(['python', self.SEND2RAZLADKI_PATH], stdin=subprocess.PIPE)
        process.stdin.write(send2razladki_line)
        process.stdin.write('\n')
        process.stdin.close()
        retcode = process.wait()
        if retcode:
            raise subprocess.CalledProcessError(retcode, self.SEND2RAZLADKI_PATH)

    def run_yql_query(self, sql):
        """Run ``sql`` on the 'hahn' cluster through YQL and wait for it.

        Uses the YQL_TOKEN environment variable when present, otherwise the
        token file ``$HOME/.yql/token``.

        :raises IOError: if neither YQL_TOKEN nor the token file is available.
        :raises Exception: if the query does not succeed.
        """
        logging.info(sql)
        if 'YQL_TOKEN' in os.environ:
            client = YqlClient(db='hahn')
        else:
            token_path = os.environ['HOME'] + '/.yql/token'
            if not os.path.exists(token_path):
                raise IOError('YQL token not found: set YQL_TOKEN or create ' + token_path)
            client = YqlClient(db='hahn', token_path=token_path)
        request = client.query(sql)
        request.run()
        results = request.get_results()
        if not results.is_success:
            raise Exception('YQL-backend failed ' + repr(results))

    def eval_razladki_for_week(self, week):
        """Run the whole per-week pipeline and send the week's metrics to Razladki.

        Every table is named ``MR_DIR + <last date of week> + <suffix>``. Each
        step is skipped when its output already exists, unless
        ``do_recalculate`` is set. The ``_razladki_sent_flag`` table marks that
        the metrics were already sent.

        :param week: non-empty list of 'YYYYMMDD' strings (see
            ``_get_weeks_from_interval``).
        """
        week_date = week[-1]

        # 1. Sample SERP HTML from the week's beast tables.
        html_table = self.MR_DIR + week_date + '_html'
        if self._needs_rebuild(html_table):
            beast_sample = BeastSample(self.do_recalculate, html_table)
            candidate_tables = [beast_sample.BEAST_TABLE_TEMPLATE.format(date=date) for date in week]
            # A real list (not a lazy filter object) so the emptiness test below works on any Python.
            beast_table_list = [table for table in candidate_tables if yt.exists(table)]
            if beast_table_list:
                logging.info('filter beast: %s -> %s' % (repr(beast_table_list), html_table))
                beast_sample.run_filter_beast(beast_table_list, html_table)
            else:
                # No input this week: an empty table lets the downstream steps still run.
                # NOTE(review): raises if the table already exists and do_recalculate is set.
                yt.create_table(html_table)

        # 2. Parse the HTML into raw element skeletons.
        raw_skel_table = self.MR_DIR + week_date + '_raw_skel'
        if self._needs_rebuild(raw_skel_table):
            self.eval_serp_element_skeletons(html_table, raw_skel_table)
            if not yt.exists(raw_skel_table):
                yt.create_table(raw_skel_table)

        # 3. Distinct (skeleton_md5, query) pairs.
        skel_query_table = self.MR_DIR + week_date + '_skel_query'
        if self._needs_rebuild(skel_query_table):
            # Sort in place by the reduce key before reducing.
            yt.run_sort(raw_skel_table, sort_by='value')
            yt.create_table(skel_query_table,
                recursive=True,
                attributes={
                    "schema": [
                        {"name": "skeleton_md5", "type": "string"},
                        {"name": "query", "type": "string"}
                    ],
                    "strict": True
                })
            yt.run_reduce(self.skeleton_uniq_query_reducer,
                source_table=raw_skel_table,
                destination_table=skel_query_table,
                reduce_by='value',
                spec=dict(title='skeleton_uniq_query_reducer')
            )

        # 4. Frequency of each TOUCH query in the week's nano_sessions.
        yandex_query_freq_table = self.MR_DIR + week_date + '_yandex_query_freq'
        if self._needs_rebuild(yandex_query_freq_table):
            yt.create_table(yandex_query_freq_table,
                recursive=True,
                attributes={
                    "schema": [
                        {"name": "query", "type": "string"},
                        {"name": "count", "type": "uint64"}
                    ],
                    "strict": True
                })
            nano_sessions_table_list = []
            for date in week:
                # nano_sessions paths use YYYY-MM-DD.
                date_yt = date[:4] + '-' + date[4:6] + '-' + date[6:]
                table_name = self.NANO_SESSIONS_TABLE_PATTERN.format(date_yt=date_yt)
                if yt.exists(table_name):
                    nano_sessions_table_list.append(table_name)
            if nano_sessions_table_list:
                yt.run_map_reduce(
                    mapper=self.nano_sessions_extract_query_mapper,
                    reducer=self.key_freq_reducer,
                    reduce_combiner=self.key_freq_reducer,
                    source_table=nano_sessions_table_list,
                    destination_table=yandex_query_freq_table,
                    reduce_by='query',
                    spec=dict(title='nano_sessions_extract_query')
                )

            # Compact chunks of every table of this week built so far.
            for table_name in yt.list(self.MR_DIR.rstrip('/')):
                if table_name.startswith(week_date):
                    yt.run_merge(self.MR_DIR + table_name, self.MR_DIR + table_name, spec={"combine_chunks": "true"})

        # 5. (feature, query) pairs with the query's nano_sessions weight.
        #    Unknown skeletons map to feature -1/'unknown'; unseen queries weigh 1.
        feature_query_freq_table = self.MR_DIR + week_date + '_feature_query_freq'
        if self._needs_rebuild(feature_query_freq_table):
            skeleton_feature_nocollision_table = self.MR_DIR + 'skeleton_feature_nocollision'
            sql = """
                USE hahn;
                INSERT INTO [{feature_query_freq_table}]
                SELECT DISTINCT feature_id, feature_name, query, yandex_query_count
                  FROM (
                        SELECT
                            COALESCE(skf.feature_id, -1) as feature_id,
                            COALESCE(skf.feature_name, 'unknown') as feature_name,
                            sq.query as query,
                            COALESCE(yq.count, 1) as yandex_query_count
                          FROM [{skel_query_table}] as sq
                               LEFT OUTER JOIN [{skeleton_feature_nocollision_table}] as skf on skf.skeleton_md5=sq.skeleton_md5
                               LEFT OUTER JOIN [{yandex_query_freq_table}] as yq on yq.query=sq.query
                    )
                ;
            """
            self.run_yql_query(self._render_sql(
                sql,
                feature_query_freq_table=feature_query_freq_table,
                skel_query_table=skel_query_table,
                skeleton_feature_nocollision_table=skeleton_feature_nocollision_table,
                yandex_query_freq_table=yandex_query_freq_table,
            ))

        # 6. Per-feature totals.
        feature_freq_table = self.MR_DIR + week_date + '_feature_freq'
        if self._needs_rebuild(feature_freq_table):
            sql = """
                USE hahn;
                INSERT INTO [{feature_freq_table}]
                SELECT feature_id, feature_name,
                    COUNT(1) as distinct_beast_query_count,
                    SUM(yandex_query_count) as yandex_query_count
                  FROM [{feature_query_freq_table}]
                  GROUP BY feature_id, feature_name
                ;
            """
            self.run_yql_query(self._render_sql(
                sql,
                feature_freq_table=feature_freq_table,
                feature_query_freq_table=feature_query_freq_table,
            ))

        # 7. Weekly totals over all sampled queries (the html table's key is the query).
        query_count_table = self.MR_DIR + week_date + '_query_count'
        if self._needs_rebuild(query_count_table):
            sql = """
                USE hahn;

                INSERT INTO [{query_count_table}]
                SELECT COUNT(1) as beast_query_count,
                      SUM(COALESCE(yq.count, 1)) as yandex_query_count
                  FROM (SELECT key as query FROM [{html_table}]) as bq
                    LEFT OUTER JOIN [{yandex_query_freq_table}] as yq on yq.query=bq.query
                ;
            """
            self.run_yql_query(self._render_sql(
                sql,
                query_count_table=query_count_table,
                html_table=html_table,
                yandex_query_freq_table=yandex_query_freq_table,
            ))

        # 8. Send metrics once; the flag table records that this week was sent.
        razladki_sent_flag_table = self.MR_DIR + week_date + '_razladki_sent_flag'
        if self._needs_rebuild(razladki_sent_flag_table):
            logging.info('get_razladki_key_values')
            razladki_key_values = self.get_razladki_key_values(feature_freq_table, query_count_table)
            self.send2razladki(week_date, razladki_key_values)
            yt.create_table(razladki_sent_flag_table)

    @staticmethod
    def _download_file(path, destination_path):
        """Stream YT file ``path`` to local ``destination_path`` in 16 MiB chunks."""
        with open(destination_path, "wb") as f:
            for chunk in chunk_iter_stream(yt.read_file(path), 16 * 1024 * 1024):
                f.write(chunk)

    def eval_tag_skeletons(self):
        """Build the shared ``MR_DIR + 'skeleton_feature_nocollision'`` table.

        Steps (each skipped when its output exists, unless ``do_recalculate``):
          1. download the lexicographically last ``backup_large_*`` archive
             from ``BACKUP_DIR`` and unpack it into ./backup_dir;
          2. for every labelled example in backup_dir/data/competitors.db,
             find the element by seanid in the saved HTML and write its
             skeleton md5 to ``SKELETON_FEATURES_FNAME``;
          3. upload that TSV to ``MR_DIR + 'skeleton_feature'``;
          4. keep only skeleton md5s that belong to exactly one feature.

        :raises IOError: if no backup archive is found.
        """
        backup_dir = 'backup_dir'
        if self.do_recalculate or not os.path.exists(backup_dir):
            logging.info('getting backup_dir from yt ' + self.BACKUP_DIR)
            backup_items = [item for item in yt.list(self.BACKUP_DIR) if item.startswith('backup_large_')]
            if not backup_items:
                raise IOError('no backup found at %s' % self.BACKUP_DIR)
            # Assumes backup names embed a sortable timestamp, so max() is the newest.
            last_backup = max(backup_items)
            backup_file = backup_dir + '/backup.tgz'
            # With do_recalculate the directory may already exist; os.mkdir would raise then.
            if not os.path.isdir(backup_dir):
                os.mkdir(backup_dir)
            self._download_file(self.BACKUP_DIR + '/' + last_backup, backup_file)
            subprocess.check_call(['tar', 'xzf', backup_file, '-C', backup_dir])

        if self.do_recalculate or not os.path.exists(self.SKELETON_FEATURES_FNAME):
            logging.info('parse html, evaluate skeleton_md52feature_id')
            sql = """
                SELECT g.feature_id, f.name, ex.serp_url_static, ex.element_id
                  FROM main_feature f
                       JOIN main_elementsgroup g on g.feature_id=f.id
                       JOIN main_example ex on ex.elements_group_id=g.id
            """
            conn = sqlite3.connect(backup_dir + '/data/competitors.db')
            try:
                cur = conn.cursor()
                serp_parser = SerpParser(self.SERP_SETTINGS)
                skeleton_evaluator = SkeletonEvaluator(self.SERP_SETTINGS)
                with codecs.open(self.SKELETON_FEATURES_FNAME, 'wb', encoding='utf8') as fo:
                    # Write each (skeleton_md5, feature_id) pair only once.
                    skeleton_md5_feature_id_set = set()
                    for lineno, (feature_id, feature_name, serp_url_static, seanid) in enumerate(cur.execute(sql)):
                        if lineno % 100 == 0:
                            logging.info('\tline ' + '\t'.join(map(unicode, (lineno, feature_id, feature_name))))
                        html_fname = backup_dir + '/media/' + serp_url_static
                        with open(html_fname) as f:
                            html = f.read()
                        dom_element = serp_parser.find_element_by_seanid(html, seanid)
                        if not dom_element:
                            logging.warning('no element with seanid=%s in %s' % (seanid, html_fname))
                            continue
                        skeleton = skeleton_evaluator.eval_skeleton(dom_element)
                        skeleton_md5 = serp_parser.eval_skeleton_md5(skeleton)
                        if (skeleton_md5, feature_id) in skeleton_md5_feature_id_set:
                            continue
                        skeleton_md5_feature_id_set.add((skeleton_md5, feature_id))
                        fo.write(u'\t'.join(map(unicode, (skeleton_md5, feature_id, feature_name, serp_url_static, seanid))) + u'\n')
            finally:
                conn.close()

        skeleton_feature_table = self.MR_DIR + 'skeleton_feature'
        if self._needs_rebuild(skeleton_feature_table):
            yt.create_table(skeleton_feature_table,
                recursive=True,
                attributes={
                    "schema": [
                        {"name": "skeleton_md5", "type": "string"},
                        {"name": "feature_id", "type": "uint64"},
                        {"name": "feature_name", "type": "string"},
                        {"name": "serp_url_static", "type": "string"},
                        {"name": "seanid", "type": "string"},
                    ],
                    "strict": True
                })
            data = []
            with open(self.SKELETON_FEATURES_FNAME) as f:
                for l in f:
                    skeleton_md5, feature_id, feature_name, serp_url_static, seanid = l.strip().split('\t')
                    data.append(dict(
                        skeleton_md5=skeleton_md5,
                        feature_id=int(feature_id),
                        feature_name=feature_name,
                        serp_url_static=serp_url_static,
                        seanid=seanid
                    ))
            yt.write_table(skeleton_feature_table, data)

        # Filter out skeleton_md5 values that have more than one feature: rows
        # are unique per (skeleton_md5, feature_id), so COUNT(1)=1 means one feature.
        skeleton_feature_nocollision_table = self.MR_DIR + 'skeleton_feature_nocollision'
        if self._needs_rebuild(skeleton_feature_nocollision_table):
            sql = """
                USE hahn;
                INSERT INTO [{skeleton_feature_nocollision_table}]
                SELECT t.skeleton_md5 as skeleton_md5, feature_id, feature_name, serp_url_static, seanid
                  FROM [{skeleton_feature_table}] as t
                      JOIN (
                          SELECT skeleton_md5
                            FROM [{skeleton_feature_table}]
                            GROUP BY skeleton_md5
                            HAVING COUNT(1)=1
                        ) as sk on t.skeleton_md5=sk.skeleton_md5
                ;
            """
            self.run_yql_query(self._render_sql(
                sql,
                skeleton_feature_nocollision_table=skeleton_feature_nocollision_table,
                skeleton_feature_table=skeleton_feature_table,
            ))

    def main(self, date_from='20161006', date_to='20170312'):
        """Prepare skeleton labels, then process every week of [date_from, date_to].

        :param date_from: first date, 'YYYYMMDD' (default keeps the historical interval).
        :param date_to: last date, 'YYYYMMDD', inclusive.
        """
        self.eval_tag_skeletons()
        # Weeks run sequentially; the module-level _eval_razladki_for_week is a
        # picklable entry point should this ever move to multiprocessing.
        for week in self._get_weeks_from_interval(date_from, date_to):
            self.eval_razladki_for_week(week)
        logging.info('done.')


def _eval_razladki_for_week(week):
    """Process one ``week`` with a freshly constructed ``AnatomyRazladki``.

    Kept at module level so it can serve as a picklable target, e.g. for
    ``multiprocessing.Pool.map``.
    """
    runner = AnatomyRazladki()
    runner.eval_razladki_for_week(week)


if __name__ == '__main__':
    logging.basicConfig(
        level=logging.INFO,
        # %(message)s is the interpolated message; %(msg)s would print the raw
        # format string for calls like logger.info('x=%s', x).
        format='%(asctime)s\t%(levelname)s\t%(threadName)s\t%(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        stream=sys.stdout,
    )
    # Silence the yt client's INFO chatter.
    yt_main.logger.LOGGER.setLevel(level=logging.WARN)
    MapReduce.useDefaults(server="hahn",
                          mrExec="mapreduce-yt",
                          verbose=True)
    AnatomyRazladki().main()
