# coding: utf-8

import textwrap
import logging
import datetime
import time
import json

import yt.wrapper as yt
import yt.yson as yson

from irt.bannerland.options import get_option as get_bl_opt
from bm.yt_tools import columns_to_schema, set_attribute
from bannerland.yql.tools import do_yql

from bannerland.archive_workers.common import BLYTWorker, convert_offer_source_str, convert_task_feed_type_str
from bannerland.archive_workers.publish import PublishWorker
from bannerland.archive_workers.transfer import TransferWorker
from bannerland.common import get_display_info, get_image_info_direct_format, get_model_info
import irt.bannerland.switcher_filter
import bannerland.validate_banners
import bannerland.video_creative

from irt.bannerland.proto.v1.bannerland import phrase_pb2 as phrase_proto


def fix_price(price):
    """Normalize a raw price value to a float.

    A comma used as the decimal separator in a string price is replaced with
    a dot before conversion.

    :param price: raw price (string, number or any other decoded JSON value).
    :return: the price as ``float``; ``None`` when the value is falsy
        (``None``, ``''``, ``0``) or cannot be converted.
    """
    if not price:
        return None

    if isinstance(price, str):
        price = price.replace(',', '.')
    try:
        return float(price)
    except (TypeError, ValueError):
        # ValueError: malformed numeric string.
        # TypeError: a value float() cannot take at all (list, dict, ...);
        # treat it as just another malformed price instead of crashing.
        logging.warning('incorrect price format: %s', price)
        return None


def _prepare_avatars_meta(row, avatars_len):
    if row.get('AvatarsMdsMeta') is not None:
        return [{'meta': meta} for meta in json.loads(row['AvatarsMdsMeta'])]
    else:
        return [None] * avatars_len


def fix_domain_fields(row):
    """Synchronize the two families of domain columns of a row in place.

    If ``Site`` is set, the ``OrigDomain*`` / ``Domain*`` / ``SiteID`` columns
    are derived from the ``Site*`` / ``TargetDomain*`` columns. Otherwise the
    copy goes the other way round.
    """
    if row.get('Site') is not None:
        copies = (
            ('SiteID', 'SiteFilterID'),
            ('OrigDomain', 'Site'),
            ('OrigDomainID', 'SiteFilterID'),
            ('Domain', 'TargetDomain'),
            ('DomainID', 'TargetDomainID'),
        )
    else:
        # TODO(i-gataullin) remove after Site in pocket
        copies = (
            ('Site', 'OrigDomain'),
            ('SiteID', 'OrigDomainID'),
            ('SiteFilter', 'OrigDomain'),
            ('SiteFilterID', 'OrigDomainID'),
            ('TargetDomain', 'Domain'),
            ('TargetDomainID', 'DomainID'),
        )

    for dst_column, src_column in copies:
        row[dst_column] = row[src_column]


def fix_perf_banner_for_caesar(row):
    """Fill the Caesar-format banner columns of a perf row in place.

    The values are derived from the row's JSON columns ``BSInfo`` and
    ``ModelCard`` and from a few plain columns (``Url``, ``Lang``,
    ``Categories``, ``avatars``, ``OfferInfo``, ``BLBannerDetails``,
    ``OrderID``). A row ``Lang`` other than ``ru``/``tr``/``en`` raises
    ``KeyError``.
    """
    # TODO(malykhin): use proto enum DYNSMART-1434
    lang_id = {'ru': 1, 'tr': 6, 'en': 3}

    bs_info = json.loads(row['BSInfo'])
    model_card = json.loads(row['ModelCard'])

    row['Href'] = row['Url']
    row['HrefText'] = bs_info.get('display_href', '')
    row['TemplateID'] = 737  # TODO(malykhin): decide where to write template ids DYNSMART-1433
    row['LangID'] = lang_id[row['Lang']]
    row['Flags'] = ''
    row['FlagIDs'] = None  # TODO(malykhin): get flags from row['Flags'].split(',') and write ids! MODDEV-2426
    row['CategoryIDs'] = [yson.YsonUint64(category) for category in row['Categories'].split(',') if category]

    # IRTDUTY-92 fix: the counter id may arrive as the literal string 'None'.
    # The card itself is patched because it is passed on to get_model_info below.
    if model_card['counter_id'] == 'None':
        model_card['counter_id'] = '0'
    row['CounterID'] = int(model_card['counter_id'])

    if not row.get('ClientID'):
        row['ClientID'] = int(bs_info['text']['client_id'])

    # complex columns
    row['CalloutSet'] = {'CalloutsList': bs_info['callouts_list']}

    price_columns = {
        'current': 'Price',
        'old': 'OldPrice',
    }
    raw_prices = bs_info.get('price', {})
    banner_price = {}
    for source_key, column in price_columns.items():
        # fix_price is used only as a validity check; the raw value is what gets stored.
        if source_key in raw_prices and fix_price(raw_prices[source_key]):
            banner_price[column] = raw_prices[source_key]

    if 'currency_iso_code' in bs_info['text']:
        banner_price['Currency'] = bs_info['text']['currency_iso_code']

    row['BannerPrice'] = banner_price

    if row.get('avatars'):
        avatars = json.loads(row['avatars'])
        metas = _prepare_avatars_meta(row, len(avatars))
        row['ImagesInfoDirectFormat'] = [
            get_image_info_direct_format(avatar, meta) for avatar, meta in zip(avatars, metas)
        ]

    display_info = get_display_info(bs_info)
    row['DisplayInfo'] = display_info
    model_info = get_model_info(model_card)
    row['ModelInfo'] = model_info

    # Note: when the row has a non-empty OfferInfo, it is modified in place below.
    offer_info = row.get('OfferInfo') or {}
    offer_info.pop('colors', None)  # TODO(optozorax): remove when frontend can accept new color format
    if model_info['AdvType'] != 'clothes':
        offer_info.pop('sizes', None)
    display_info.pop('sizes', None)
    display_info.update(offer_info)

    row['OfferSource'] = convert_offer_source_str(row['BLBannerDetails'].get('offer_source', 'feed'))
    row.setdefault('AutogeneratedOfferID', None)

    order_id = row['OrderID']
    if bannerland.video_creative.check_order(order_id):
        row['VideoCreative'] = bannerland.video_creative.get_creative(order_id, row['Href'])


def fix_dyn_banner_for_caesar(row):
    """Fill the Caesar-format banner columns of a dyn row in place.

    The values are derived from the plain columns ``Url``, ``UrlText``,
    ``Lang``, ``Categories``, ``Price``, ``Currency``, ``avatars``,
    ``OfferInfo``, ``BLBannerDetails`` and from the JSON column ``Info``.
    A row ``Lang`` other than ``ru``/``tr``/``en`` raises ``KeyError``.
    """
    # TODO(malykhin): use proto enum DYNSMART-1434
    lang_id = {'ru': 1, 'tr': 6, 'en': 3}

    row['Href'] = row['Url']
    row['HrefText'] = row['UrlText']
    row['LangID'] = lang_id[row['Lang']]
    row['Flags'] = ''
    row['FlagIDs'] = None  # TODO(malykhin): get flags from row['Flags'].split(',') and write ids! MODDEV-2426
    row['CategoryIDs'] = [yson.YsonUint64(category) for category in row['Categories'].split(',') if category]

    info = json.loads(row['Info'])

    banner_price = {}
    if row['Price']:
        banner_price['Price'] = row['Price']
        # The old price is only written together with a current price.
        old_price = info.get('oldprice')
        if old_price:
            banner_price['OldPrice'] = old_price
    banner_price['Currency'] = row['Currency']
    row['BannerPrice'] = banner_price

    if row.get('avatars'):
        avatars = json.loads(row['avatars'])
        metas = _prepare_avatars_meta(row, len(avatars))
        row['ImagesInfoDirectFormat'] = [
            get_image_info_direct_format(avatar, meta) for avatar, meta in zip(avatars, metas)
        ]

    # Note: DisplayInfo is the very same object as a non-empty OfferInfo,
    # so the pops below affect both columns.
    display_info = row.get('OfferInfo') or {}
    row['DisplayInfo'] = display_info
    if info.get('adv_type') != 'clothes':
        display_info.pop('sizes', None)
    display_info.pop('colors', None)

    row['OfferSource'] = convert_offer_source_str(row['BLBannerDetails'].get('offer_source', 'feed'))


def fix_phrase_for_caesar(row):
    """Write ``NormTypeID`` into the row: the proto enum value for its ``NormType``.

    Norm types outside the known set map to ``ENormType.NoneType``.
    """
    norm_enum = phrase_proto.ENormType
    known_norm_types = {
        'norm': norm_enum.Norm,
        'snorm': norm_enum.Snorm,
        'offer': norm_enum.Offer,
        'offer_group': norm_enum.OfferGroup,
    }
    fallback = norm_enum.NoneType
    row['NormTypeID'] = known_norm_types.get(row['NormType'], fallback)


# - determine the pocket of delta rows via new_pockets (the new tables come first in the input list) and write it into the row
# - drop rows from pockets that are not in active_pockets
# - drop rows that do not come from the task's last pocket
@yt.with_context
class ActivePocketsMapper:
    """Mapper that merges new pocket tables with the previous full state.

    For every input row it:
      * writes the ``pocket`` column for rows that lack it, using the input
        table index as an index into ``new_pockets``;
      * drops rows whose pocket is not active;
      * drops rows that are not from the last pocket of their task;
      * converts legacy rows to the current (Caesar) column layout;
      * drops perf rows without avatars and banned/bad "simple" titles.
    """

    def __init__(self, task_type, active_pockets, new_pockets, task_last_pocket, fullstate_table_index, task_feed_types):
        """
        :param task_type: ``'perf'`` or ``'dyn'``.
        :param active_pockets: pockets that may stay in the resulting full state.
        :param new_pockets: pocket names of the new input tables, in input order.
        :param task_last_pocket: dict ``task_id -> most recent pocket`` for tasks of the delta.
        :param fullstate_table_index: input table index of the previous full state.
        :param task_feed_types: dict ``task_id -> feed type``.
        """
        self.task_type = task_type
        self.active_pockets_set = set(active_pockets)
        self.new_pockets = new_pockets
        self.task_last_pocket = task_last_pocket
        self.fullstate_table_index = fullstate_table_index
        self.banned_simple_generation_orders = set(get_bl_opt('banned_simple_generation_orders'))
        self.task_feed_types = task_feed_types

    def _is_actual(self, row, context):
        """Tell whether the row belongs to an active pocket and is the latest data of its task."""
        pocket = row['pocket']
        if pocket not in self.active_pockets_set:
            return False

        tid = row['task_id']
        if tid not in self.task_last_pocket:
            # The task is absent from the delta: keep its rows only if they
            # come from the old full state.
            return context.table_index == self.fullstate_table_index

        # The task is present in the delta: keep only its latest version,
        # everything else is outdated.
        return self.task_last_pocket[tid] == pocket

    @staticmethod
    def _drop_inconsistent_old_price(banner_price):
        """Remove ``OldPrice`` when it is unparsable or lower than the current price.

        Prices are parsed with ``fix_price`` because stored values may use a
        comma as the decimal separator, which a bare ``float()`` rejects.
        A missing, empty or unparsable current price counts as 0.
        """
        raw_old_price = banner_price.get('OldPrice')
        if not raw_old_price:
            return

        old_price = fix_price(raw_old_price)
        current_price = fix_price(banner_price.get('Price', '0')) or 0.0
        if old_price is None or current_price > old_price:
            del banner_price['OldPrice']

    def __call__(self, row, context):
        if 'pocket' not in row:
            # Delta rows have no pocket yet; the new tables go first in the input list.
            row['pocket'] = self.new_pockets[context.table_index]

        if not self._is_actual(row, context):
            return

        is_perf = self.task_type == 'perf'

        # legacy patch, allows to merge old patches in full state
        fix_domain_fields(row)
        if 'Href' not in row or ('OfferSource' not in row and is_perf):
            if is_perf:
                fix_perf_banner_for_caesar(row)
                fix_phrase_for_caesar(row)
            elif self.task_type == 'dyn':
                fix_dyn_banner_for_caesar(row)
                fix_phrase_for_caesar(row)

        feed_type = self.task_feed_types.get(row['task_id'], None)
        row['InGrut'] = irt.bannerland.switcher_filter.is_row_for_grut(row, self.task_type, feed_type)

        if is_perf and (not row['avatars'] or row['avatars'] == '[]'):
            return

        if is_perf:
            row.pop('Info', None)

        self._drop_inconsistent_old_price(row['BannerPrice'])

        details = row['BLBannerDetails']
        if details.get('title_source') == 'simple':
            if row['OrderID'] in self.banned_simple_generation_orders:
                return

            # IRTDUTY-207: workaround, drop bad banners (with phrases in Title)
            if 'phrases' in details.get('title_template', ''):
                return

        yield row


class AddToFullStateWorker(BLYTWorker):
    """Worker that merges a freshly generated pocket into the full state (FS) table."""

    def __init__(self, task_type, active_days_count, **kwargs):
        """
        :param task_type: ``'perf'`` or ``'dyn'``; selects the FS schema.
        :param active_days_count: default number of most recent days whose pockets stay in the FS.
        :raises ValueError: for any other ``task_type``.
        """
        super(AddToFullStateWorker, self).__init__(task_type, **kwargs)

        self.active_days_count = active_days_count
        # Name of the FS table attribute that stores the list of pockets merged into it.
        self.pocket_list_attr = 'bannerland_pocket_list'

        if task_type == 'perf':
            self.fs_schema = columns_to_schema(get_bl_opt('perf_result_columns'))
        elif task_type == 'dyn':
            self.fs_schema = columns_to_schema(get_bl_opt('dyn_result_columns'))
        else:
            raise ValueError('task_type not found')

    def get_active_pockets(self, pocket_list, active_days_count=None):
        """Return the pockets that belong to the ``active_days_count`` most recent days.

        Only days that actually occur in ``pocket_list`` are counted. The
        order of ``pocket_list`` is preserved in the result.

        :param pocket_list: pocket names in the ``bannerland_pocket_name_format`` format.
        :param active_days_count: number of days to keep; defaults to ``self.active_days_count``.
        """
        if active_days_count is None:
            active_days_count = self.active_days_count
        dt_format = get_bl_opt('bannerland_pocket_name_format')
        pocket_day = {
            pocket: datetime.datetime.strptime(pocket, dt_format).strftime("%Y%m%d")
            for pocket in pocket_list
        }
        all_days = set(pocket_day.values())
        last_days = set(sorted(all_days, reverse=True)[:active_days_count])
        return [p for p in pocket_list if pocket_day[p] in last_days]

    def merge_fs(self, pocket_dir_list, input_fs, output_fs, active_days_count=None):
        """Low-level merge of pockets into the FS; used in production, tests and manual runs.

        :param pocket_dir_list: directories of the new pockets; each must contain
            ``tasks.final`` and ``generated_banners.final``.
        :param input_fs: path of the current FS table.
        :param output_fs: path of the FS table to create.
        :param active_days_count: number of days to keep; defaults to ``self.active_days_count``.
        :raises Exception: if one of the new pockets is already listed in ``input_fs``.
        """
        if active_days_count is None:
            active_days_count = self.active_days_count
        yt_client = self.yt_client

        # For every task of the delta find its most recent pocket.
        # The dict for all tasks of the delta (!) is kept in memory; there are
        # 5-10k of them now, so the margin is very large. In return a single
        # mapper is enough, without a reduce by task_id: an FS for a month can
        # be built in tens of minutes.
        task_last_pocket = {}
        task_feed_types = {}
        for pocket_dir in sorted(pocket_dir_list, reverse=True):
            pocket = pocket_dir.split('/')[-1]
            tasks_table_view = yt.TablePath(pocket_dir + '/tasks.final', columns=['task_id', 'task_inf'])
            for row in yt_client.read_table(tasks_table_view):
                tid = row['task_id']
                if tid not in task_last_pocket:
                    task_last_pocket[tid] = pocket
                if tid in task_feed_types:
                    # Pockets are scanned newest first: keep the feed type of the
                    # most recent pocket that has one, do not let older pockets
                    # overwrite it.
                    continue
                resource = json.loads(row['task_inf'])['Resource']
                if 'LastValidFeedType' in resource:
                    task_feed_types[tid] = convert_task_feed_type_str(resource['LastValidFeedType'])
        new_tasks_count = sum(1 for pocket in task_last_pocket.values() if pocket > '0')
        logging.info('tasks in new pockets: %d', new_tasks_count)

        new_pocket_tables = [pocket_dir + '/generated_banners.final' for pocket_dir in pocket_dir_list]
        new_pockets = [pocket_dir.split('/')[-1] for pocket_dir in pocket_dir_list]
        logging.info('new pocket tables: %s', new_pocket_tables)

        pocket_list_attr = self.pocket_list_attr
        old_pockets = yt_client.get_attribute(input_fs, pocket_list_attr, [])
        if set(new_pockets) & set(old_pockets):
            raise Exception("some new pocket already in fs!")

        active_pockets = self.get_active_pockets(
            pocket_list=old_pockets + new_pockets,
            active_days_count=active_days_count,
        )
        logging.info('active pockets: %s', active_pockets)

        # The new pocket tables go first in the input list and the old FS goes
        # last, hence the FS table index equals the number of new pockets.
        mapper = ActivePocketsMapper(
            task_type=self.task_type,
            active_pockets=active_pockets,
            new_pockets=new_pockets,
            task_last_pocket=task_last_pocket,
            fullstate_table_index=len(new_pockets),
            task_feed_types=task_feed_types
        )

        output_fields = [col['name'] for col in self.fs_schema]
        input_tables = new_pocket_tables + [input_fs]
        input_table_paths = [yt.TablePath(table, columns=output_fields) for table in input_tables]

        with yt_client.Transaction():
            yt_client.create('table', output_fs, attributes={'schema': self.fs_schema, 'optimize_for': 'scan'})
            yt_client.run_map(
                mapper,
                input_table_paths,
                output_fs,
                spec={
                    'mapper': {
                        'memory_reserve_factor': 0.8,
                        'memory_limit': 500000000,
                    },
                },
            )
            yt_client.set_attribute(output_fs, pocket_list_attr, list(active_pockets))

    def do_work(self, pocket_dir):
        """Merge the pocket from ``pocket_dir`` into the latest archived FS and archive the result."""
        logging.info('do_work for %s', pocket_dir)
        yt_client = self.yt_client

        cypress_conf = self.get_cypress_config()
        fs_root = cypress_conf.get_path('full_state_dir')
        fs_archive = cypress_conf.get_path('full_state_archive')

        fs_dirs = yt_client.list(fs_archive, absolute=True)
        if fs_dirs:
            current_dir = max(fs_dirs)
            current_fs = yt.ypath_join(current_dir, 'bannerphrases')
        else:
            # First run: create an empty FS inside the pocket directory and merge with it.
            current_dir = pocket_dir
            current_fs = yt.ypath_join(current_dir, 'empty_fs')
            yt_client.create('table', current_fs, attributes={'schema': self.fs_schema}, ignore_existing=True)
            yt_client.set_attribute(current_fs, self.pocket_list_attr, [])  # optional

        new_dir = yt.ypath_join(fs_root, 'new')
        yt_client.remove(new_dir, force=True, recursive=True)
        yt_client.create('map_node', new_dir)

        new_fs = yt.ypath_join(new_dir, 'bannerphrases')
        self.merge_fs(
            pocket_dir_list=[pocket_dir],
            input_fs=current_fs,
            output_fs=new_fs,
        )

        # This is the place where Orders could later be limited by banner count.

        # Link to the previous FS: a diff between the old and the new one may
        # be needed, e.g. new phrases for Bordonos.
        yt_client.link(current_dir, yt.ypath_join(new_dir, 'prev_fs_dir'))

        # Timestamps can be written either in a human-readable form or as unixtime.
        # Unixtime is inconvenient to read, so the former is used here. It does
        # not guarantee monotonicity though (winter/summer time switch); in that
        # case the update can be skipped (an hour per half a year) or a second
        # can be added to the previous timestamp.
        fs_dirname = datetime.datetime.now().strftime(self.dt_format)
        final_dir = yt.ypath_join(fs_archive, fs_dirname)
        if current_dir != pocket_dir and final_dir < current_dir:
            logging.warning("Time of new fs is less than old fs, can't create new table, maybe next time!")
            return

        yt_client.move(new_dir, final_dir)
        yt_client.set_attribute(final_dir, 'last_pocket', pocket_dir)


# archive workers that operate on the full state archive
class FSWorker(TransferWorker):
    """Common base of the full-state archive workers; holds the shared node and attribute names."""

    # Name of the finalized full-state table inside an FS directory.
    final_table_name = 'bannerphrases.final'
    # Attribute that stores the export format version.
    export_version_attr = 'bannerland_export_version'
    # Name of the subdirectory that holds the separated 'banners' and 'phrases' tables.
    final_dir_name = 'final'
    # Attribute that stores the full state id (set by PublishFSWorker).
    full_state_id_attr = 'FullStateID'


class FinalizeFSWorker(FSWorker):
    """Gives the full-state table its final name and stamps the export version on it."""

    def __init__(self, input_name, **kwargs):
        super(FinalizeFSWorker, self).__init__(**kwargs)
        self.input_name = input_name

    def do_work(self, fs_dir):
        """Move ``input_name`` to the final table name (leaving a link behind) and set the version."""
        version = 'v2' if self.task_type == 'perf' else 'v1'

        client = self.yt_client
        final_path = yt.ypath_join(fs_dir, self.final_table_name)

        if self.input_name != self.final_table_name:
            client.move(yt.ypath_join(fs_dir, self.input_name), final_path)
            # Keep the old name reachable as a link to the final table.
            client.link(final_path, yt.ypath_join(fs_dir, self.input_name))

        client.set_attribute(final_path, self.export_version_attr, version)


class SeparateFSWorker(FSWorker):
    """Splits the finalized full-state table into separate ``banners`` and ``phrases`` tables."""

    def select_and_group(self, input, output, fields, groupby_fields, minby_fields, rename=None, agr_fields=None, agr_res_field=None, agr_limit=50, agr_sort_fields=None):
        """Run a YQL query that selects and groups the given columns.

        Group-by columns are selected as is; every other column is taken via
        ``MIN_BY`` over ``minby_fields``. ``rename`` maps source column names
        to output names. When ``agr_fields``, ``agr_res_field`` and
        ``agr_sort_fields`` are all non-empty, an extra column
        ``agr_res_field`` is added: a sorted list (limited by ``agr_limit``)
        of structs built from ``agr_fields``.
        """
        # MIN_BY is used instead of SOME to avoid broken data on BannerID collisions:
        # SOME does not guarantee that values of different columns come from the same row.
        rename = {} if rename is None else rename

        quote = '`{}`'.format

        select_items = []
        for src_name in fields:
            dst_name = rename.get(src_name, src_name)
            if src_name not in groupby_fields:
                item = 'MIN_BY({src_field}, {by_field}) as {dst_field}'.format(
                    src_field=quote(src_name),
                    dst_field=quote(dst_name),
                    by_field='({})'.format(','.join(quote(by_name) for by_name in minby_fields)),
                )
            elif src_name == dst_name:
                item = quote(src_name)
            else:
                item = '{src_field} AS {dst_field}'.format(src_field=quote(src_name), dst_field=quote(dst_name))
            select_items.append(item)

        if agr_fields and agr_res_field and agr_sort_fields:
            # Each sort field is a (name, type) pair.
            sort_keys = ', '.join(
                'Yson::ConvertTo($x["{name}"], {type})'.format(name=sort_field[0], type=sort_field[1])
                for sort_field in agr_sort_fields
            )
            comparator = '($x) -> {{RETURN ({fields_list})}}'.format(fields_list=sort_keys)
            struct_members = [
                '{field} as {dst_field}'.format(field=quote(name), dst_field=quote(rename.get(name, name)))
                for name in agr_fields
            ]
            yql_query_agr = textwrap.dedent("""
            ListSort(
                AGGREGATE_LIST(
                    Yson::Serialize(
                        Yson::From(
                            AsStruct(
                                {agr_fields}
                            )
                        )
                    )
                , {limit}), {comparator}
            ) as {res_field}
            """).format(
                agr_fields=',\n  '.join(struct_members),
                res_field=quote(agr_res_field),
                limit=agr_limit,
                comparator=comparator
            )
            select_items.append(yql_query_agr)

        group_clause = ','.join(quote(name) for name in groupby_fields)
        yql_query = textwrap.dedent("""
            PRAGMA yt.TemporaryAutoMerge="disabled";

            INSERT INTO `{output}` WITH TRUNCATE
            SELECT
              {what}
            FROM `{input}`
            GROUP BY {groupby}
            ORDER BY {groupby}
        """).format(
            input=input, output=output,
            what=',\n  '.join(select_items),
            groupby=group_clause,
        )
        do_yql(self.yql_client, yql_query, yt_pool=self.yt_pool)

    def do_work(self, fs_dir):
        """Build the ``banners`` and ``phrases`` tables in the final subdirectory of ``fs_dir``."""
        # task type -> (option with the result columns, output name of the 'avatars' column)
        per_task_type = {
            'perf': ('perf_result_columns', 'Avatars'),
            'dyn': ('dyn_result_columns', 'Images'),
        }
        if self.task_type not in per_task_type:
            raise ValueError('Task type not supported: {}'.format(self.task_type))
        columns_option, new_avatars_name = per_task_type[self.task_type]
        all_columns = get_bl_opt(columns_option)

        fs_path = yt.ypath_join(fs_dir, self.final_table_name)
        dst_dir = yt.ypath_join(fs_dir, self.final_dir_name)
        self.yt_client.create('map_node', dst_dir, ignore_existing=True)

        fs_fields_set = {col['name'] for col in self.yt_client.get(fs_path + '/@schema')}
        all_columns = [col for col in all_columns if col['name'] in fs_fields_set]  # if fs has old schema

        def names_with(flag):
            # Names of the columns whose description has a truthy `flag` entry.
            return [col['name'] for col in all_columns if col.get(flag)]

        agr_limit = get_bl_opt('max_phrases_per_banner').get(self.task_type)
        agr_fields = names_with('banner.BroadPhrases')
        agr_sort_fields = [
            (col['name'], col['type'])
            for col in all_columns
            if col.get('phrase_key') and col['name'] in agr_fields
        ]

        self.select_and_group(
            input=fs_path,
            output=yt.ypath_join(dst_dir, 'banners'),
            fields=names_with('banner'),
            groupby_fields=names_with('banner_key'),
            minby_fields=names_with('phrase_key'),
            rename={'avatars': new_avatars_name, 'NormTypeID': 'NormType'},
            agr_fields=agr_fields,
            agr_res_field='BroadPhrases',
            agr_limit=agr_limit,
            agr_sort_fields=agr_sort_fields,
        )

        self.select_and_group(
            input=fs_path,
            output=yt.ypath_join(dst_dir, 'phrases'),
            fields=names_with('phrase'),
            groupby_fields=names_with('phrase_key'),
            minby_fields=['bannerphrase_md5'],
        )
        set_attribute(dst_dir, self.export_version_attr, 'v3.1', yt_client=self.yt_client)


class PublishFSWorker(PublishWorker, FSWorker):
    """Publishes the final directory of a full state to the export location."""

    def __init__(self, input_name=None, **kwargs):
        super(PublishFSWorker, self).__init__(**kwargs)
        self.input_name = input_name

    def do_work(self, fs_dir):
        """Stamp the final directory with a fresh full state id and publish it."""
        client = self.yt_client
        # The id is the current unix time in whole seconds.
        state_id = int(time.time())
        sensor = self.get_sensor_name()

        export_dir = self.get_cypress_config().get_path('full_state_export')

        logging.info('separate fs')
        final_dir = yt.ypath_join(fs_dir, self.final_dir_name)
        set_attribute(final_dir, self.full_state_id_attr, state_id, yt_client=client)
        self.publish(
            final_dir,
            target_dir=export_dir,
            version_attr=self.export_version_attr,
            copy_attrs=[self.full_state_id_attr],
            sensor_name=sensor,
        )

    def sort(self, fs_path):
        """Sort the table by order id, group/parent export id and banner id."""
        group_column = 'GroupExportID' if self.task_type == 'perf' else 'ParentExportID'
        self.yt_client.run_sort(fs_path, sort_by=['OrderID', group_column, 'BannerID'])

    def get_sensor_name(self):
        """Return the name of the actuality sensor for this task type."""
        return "{}.full_state.actuality".format(self.task_type)


class CheckFinalTableWorker(FSWorker):
    """Checks the final ``banners`` table against the two previous full states."""

    def do_work(self, fs_dir):
        """Compare banner counters with the previous and the pre-previous FS.

        :raises RuntimeError: as soon as one comparison fails.
        """
        tail = (self.final_dir_name, 'banners')
        final_table = yt.ypath_join(fs_dir, *tail)
        # One and two hops back along the 'prev_fs_dir' links.
        old_tables = [
            yt.ypath_join(fs_dir, *(('prev_fs_dir',) * hops + tail))
            for hops in (1, 2)
        ]

        for old_table in old_tables:
            logging.info('comparing tables %s and %s ...', final_table, old_table)
            if not bannerland.validate_banners.compare_banners_table_counters(final_table, old_table, self.yt_client):
                raise RuntimeError('compare failed: {} , {}'.format(final_table, old_table))
            logging.info('compare ok')
