# -*- coding: utf-8 -*-
import sys
import six
import os
import json
import enum
from calendar import timegm
from datetime import datetime
from collections import Counter
import logging
import time
import dateutil.parser

import yt.wrapper as yt
from yt.wrapper.cypress_commands import _KWARG_SENTINEL
import yt.yson as yson
import irt.iron.options as iron_opts
import irt.broadmatching.common_options

try:
    long  # a builtin on Python 2 only
except NameError:
    # Python 3 has no separate ``long`` type: alias it to ``int`` so code in this
    # module can refer to ``long`` on both major versions.
    long = int


def gen_packs(iterable, pack_size):
    """Lazily split ``iterable`` into consecutive lists of ``pack_size`` items.

    The last list may be shorter; nothing is yielded for an empty iterable.
    A non-positive ``pack_size`` behaves like 1 (one item per list).
    """
    current = []
    for element in iterable:
        current.append(element)
        if len(current) < pack_size:
            continue
        yield current
        current = []
    if current:
        yield current


def get_attribute(table, attribute, yt_client, default_value=_KWARG_SENTINEL):
    """Read one attribute of a Cypress node.

    Args:
        table: path of the node.
        attribute: attribute name (without the ``@`` prefix).
        yt_client: YT client (or the ``yt.wrapper`` module) used for the request.
        default_value: returned when the node has no such attribute; when
            omitted, a missing attribute raises RuntimeError.

    Returns:
        The attribute value (or ``default_value``).

    Raises:
        RuntimeError: the attribute is missing and no default was given.
        Any exception from the YT request is logged and re-raised unchanged.
    """
    try:
        # NOTE(review): ``get`` returns the node value together with the
        # requested attribute; for map nodes this may transfer a whole subtree —
        # verify whether reading ``<path>/@<attribute>`` directly is preferable.
        value = yt_client.get(table, attributes=[attribute]).attributes.get(attribute, default_value)
    except Exception:
        # Lazy %-args for logging; bare ``raise`` keeps the original traceback
        # (``raise e`` would reset it on Python 2).
        logging.exception('Exception at get_attribute(%s,%s)', table, attribute)
        raise
    # Compare by identity: the sentinel is a marker, not a real value.
    if value is _KWARG_SENTINEL:
        raise RuntimeError('not found attr = {}'.format(attribute))
    return value


def set_attribute(table, attribute, value, yt_client):
    """Set attribute ``attribute`` of node ``table`` to ``value`` using ``yt_client``."""
    attribute_path = yt.ypath_join(table, '@' + attribute)
    yt_client.set(attribute_path, value)


def get_type(yt_object, yt_client=yt):
    """Return the Cypress ``type`` attribute of ``yt_object``.

    The path gets a trailing ``&``; presumably this stops YPath from following
    links, so a link reports its own type — verify against YT docs.
    """
    unresolved_path = '{}&'.format(yt_object)
    return get_attribute(unresolved_path, 'type', yt_client)


def parse_fmt_str(fmt_str):
    """Parse ``"name:type,name:type,..."`` into ``[[name, python_type], ...]``.

    Known type names: int, int32, uint32, long, int64, uint64, float, str.
    An unknown type name maps to ``None``; an empty string yields ``[]``.
    Each item must contain exactly one ``:`` (otherwise ValueError).
    """
    type_by_name = {
        "int": int,
        "int32": int,
        "uint32": int,
        "long": long,
        "int64": yson.YsonInt64,
        "uint64": yson.YsonUint64,
        "float": float,
        "str": str,
    }
    if fmt_str == "":
        return []
    pairs = (item.split(":") for item in fmt_str.split(","))
    return [[name, type_by_name.get(type_name)] for name, type_name in pairs]


class NormalizeBase(object):
    """Base class for jobs that normalize phrases with ``irt.common.cnormalizer``.

    ``__init__`` only records which dictionary files are needed; the normalizers
    are built in ``start()`` from the files' base names (relative paths, so they
    are resolved against the current working directory — for YT files, the job
    sandbox).

    Args:
        local_files: if True, dictionary paths come from
            ``irt.broadmatching.common_options`` and are stored in ``self.files``;
            otherwise they come from the iron option
            ``yt_files.normalization_dicts`` and are stored in ``self.yt_files``.
        remove_stop_words: load the Russian stop-word dictionary into both normalizers.
    """

    def __init__(self, local_files=False, remove_stop_words=True):
        self.remove_stop_words = remove_stop_words
        lib_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "lib")
        # Append only once: every instantiation used to grow sys.path.
        if lib_dir not in sys.path:
            sys.path.append(lib_dir)
        import irt.common.cnormalizer as cnormalizer
        self.cnormalizer = cnormalizer
        self.local_files = local_files

        if local_files:
            options = irt.broadmatching.common_options.get_options()
            norm_dict = str(options["Words_params.norm_dict"])
            snorm_dict = str(options["Words_params.word2snorm_dict"])
            # stopword_dicts is a sequence of (path, language) pairs
            stw_dicts = {str(x[1]): str(x[0]) for x in options["Words_params.stopword_dicts"]}
            stop_dict = stw_dicts["ru"]

            self.files = []
            self._name = {}
            # tuple of pairs keeps the order of self.files deterministic
            for n, f in (("norm_dict", norm_dict), ("snorm_dict", snorm_dict), ("stop_dict", stop_dict)):
                self.files.append(f)
                self._name[n] = os.path.basename(f)

        else:
            # YT files
            self._yt_name = {}
            self.yt_files = []
            all_norm_dicts = iron_opts.get("yt_files")["normalization_dicts"]
            for n, f in all_norm_dicts.items():
                self._yt_name[n] = f.split('/')[-1]  # local names in sandbox directory
                self.yt_files.append(f)

    def start(self):
        """Create ``self.normalizer`` and ``self.snormalizer`` from the dictionary files."""
        names = self._name if self.local_files else self._yt_name
        self.normalizer = self.cnormalizer.Normalizer(names["norm_dict"])
        if self.remove_stop_words:
            self.normalizer.load_stop_words(names["stop_dict"], 'ru')
        self.snormalizer = self.cnormalizer.Normalizer(names["snorm_dict"])
        if self.remove_stop_words:
            self.snormalizer.load_stop_words(names["stop_dict"], 'ru')

    def norm_phr(self, text, sort=True, uniq=False):
        """Normalize ``text`` (language "ru") with the norm-dictionary normalizer."""
        return self.normalizer.normalize(text, "ru", uniq, sort)

    def snorm_phr(self, text, sort=True, uniq=False):
        """Normalize ``text`` (language "ru") with the snorm-dictionary normalizer."""
        return self.snormalizer.normalize(text, "ru", uniq, sort)


class NormalizeMapper(NormalizeBase):
    """Normalizer job class (by name, for use as a mapper); all logic lives in NormalizeBase."""


class NormalizeReducer(NormalizeBase):
    """Normalizer job class (by name, for use as a reducer); all logic lives in NormalizeBase."""


class JoinReducer(object):
    """Inner join of the first input table with the union of all the other inputs.

    Rows of the first table for the current key are buffered in memory; each
    subsequent row of another table is merged with every buffered row (on
    conflicting columns the other table's value wins).

    Example::

        yt.run_reduce(JoinReducer(), [left_table, right_table_1, right_table_2],
                      output_table, reduce_by=[..])

    Args:
        table_index_field: field holding the table number (default ``'@table_index'``).
            Only equality to zero matters: it marks rows of the "first" table.
            The field is removed from joined rows. When used via
            ``run_map_reduce`` the field must be filled in and passed here, and
            the data must be sorted by reduce_by + this field, because a row of
            another table is joined only with first-table rows seen before it.
    """

    def __init__(self, table_index_field='@table_index'):
        self.table_index_field = table_index_field

    def __call__(self, key, rows):
        index_field = self.table_index_field
        buffered = []
        for current in rows:
            if current[index_field] != 0:
                for joined in self._join_with(buffered, current, index_field):
                    yield joined
            else:
                buffered.append(current)

    @staticmethod
    def _join_with(left_rows, right_row, index_field):
        """Yield a copy of each of ``left_rows`` updated by ``right_row``, minus ``index_field``."""
        for left in left_rows:
            joined = left.copy()
            joined.update(right_row)
            del joined[index_field]
            yield joined


class TableIndexCopyMapper(object):
    """Mapper that moves ``@table_index`` into the ordinary ``_table_index`` column.

    This keeps the table index available at the reduce stage, which is useful
    for joining two unsorted tables.
    """

    def __call__(self, row):
        row['_table_index'] = row.pop('@table_index')
        yield row


class TableIndexCompatibleJoinReducer(object):
    """Inner or left join of two tables that does not depend on row order.

    The table index is taken from ``_table_index`` if present, else from
    ``@table_index``; 0 is the left table, 1 the right one. All rows of a key
    are buffered, so rows may arrive in any order — useful for joining two
    unsorted tables. Both index fields are removed from output rows; on
    conflicting columns the right row's value wins.

    Args:
        type: ``'inner'`` (default) or ``'left'``; with ``'left'``, left rows of
            a key that has no right rows are emitted as is.

    Raises:
        Exception: ``type`` is neither ``'inner'`` nor ``'left'``.
    """

    def __init__(self, type='inner'):
        if type not in ('inner', 'left'):
            raise Exception("Bad type of TableIndexCompatibleJoinReducer")
        self.type = type

    @staticmethod
    def delete_table_index(row):
        """Remove ``@table_index`` and ``_table_index`` from ``row`` in place, if present."""
        for field in ('@table_index', '_table_index'):
            row.pop(field, None)

    def __call__(self, key, rows):
        left_rows, right_rows = [], []
        for row in rows:
            table_index = row['_table_index'] if '_table_index' in row else row['@table_index']
            (left_rows, right_rows)[table_index].append(row)

        emit_unmatched = self.type == 'left' and not right_rows
        for left in left_rows:
            if emit_unmatched:
                self.delete_table_index(left)
                yield left
                continue
            for right in right_rows:
                merged = dict(left)
                merged.update(right)
                self.delete_table_index(merged)
                yield merged


class FirstReducer(object):
    """Reducer that emits only the first row of each group, without ``@table_index``."""

    def __call__(self, key, rows):
        for first_row in rows:
            first_row.pop("@table_index", None)
            yield first_row
            return
            break


def get_yt_bm_config(production=False):
    """Build a YT client config dict.

    The proxy comes from ``$YT_PROXY`` (default: iron option ``yt_default_proxy``).
    The token comes from ``$YT_TOKEN``; otherwise it is read (stripped) from the
    file ``$YT_TOKEN_PATH`` (default: iron option ``yt_token_path``).
    With ``production=True`` the ``broadmatching`` pool is added to spec
    defaults and an ``operation_tracker`` section is set.
    """
    environ = os.environ
    proxy_url = environ.get('YT_PROXY', iron_opts.get('yt_default_proxy'))

    if 'YT_TOKEN' in environ:
        yt_token = environ['YT_TOKEN']
    else:
        token_file = environ.get('YT_TOKEN_PATH', iron_opts.get('yt_token_path'))
        with open(token_file) as token_fh:
            yt_token = token_fh.read().strip()

    spec_defaults = {"mapper": {"tmpfs_path": "."}}
    config = {
        "proxy": {"url": proxy_url},
        "token": yt_token,
        "read_retries": {"enable": True},
        "spec_defaults": spec_defaults,
    }
    if production:
        # production pool
        spec_defaults["pool"] = "broadmatching"
        config["operation_tracker"] = {"ignore_stderr_if_download_failed": True}
    return config


def join_schemas_from_tables(tables, yt_client=yt):
    """Merge the schemas of ``tables`` into one schema, checking column types.

    ``required`` is not taken into account: merged columns carry only name and type.
    """
    table_schemas = [yt_client.get_attribute(path, 'schema') for path in tables]
    return join_schemas_from_list(table_schemas)


def join_schemas_from_list(schemas):
    """Merge YSON schemas into one ``yson.YsonList`` schema.

    Columns appear in first-seen order with ``{'name', 'type'}`` only. The result
    is strict only if every input schema is strict.

    Raises:
        Exception: a column appears with different types in different schemas.
    """
    ordered_names = []
    type_by_name = {}
    strict = True  # strict only when all input schemas are strict

    for schema_part in schemas:
        if not schema_part.attributes['strict']:
            strict = False
        for column in schema_part:
            column_name = column['name']
            column_type = column['type']
            if column_name in type_by_name:
                if type_by_name[column_name] != column_type:
                    raise Exception("Incompatible types for: " + column_name)
                continue
            ordered_names.append(column_name)
            type_by_name[column_name] = column_type

    merged = yson.YsonList([{'name': n, 'type': type_by_name[n]} for n in ordered_names])
    merged.attributes['strict'] = strict
    return merged


def set_upload_time(table, yt_client=yt):
    """Store the current UTC time as a naive ISO-8601 string in the ``upload_time`` attribute."""
    now_iso = datetime.utcnow().isoformat()
    yt_client.set_attribute(table, "upload_time", now_iso)


def get_upload_time(table, yt_client=yt):
    """Return the table's ``upload_time`` attribute as Unix seconds (float).

    The attribute is parsed with ``dateutil``. A timezone offset, if present,
    is applied; a naive value is taken as UTC (the form ``set_upload_time`` writes).

    Returns:
        Seconds since the epoch, or None if the attribute is missing or cannot be
        read/parsed (the failure is logged, not raised).
    """
    # Everything is inside try: the value may be malformed, or the YT call may fail.
    try:
        upload_time = yt_client.get_attribute(table, "upload_time", default=None)
        if upload_time is None:
            return None
        upload_datetime = dateutil.parser.parse(upload_time)
        offset = upload_datetime.utcoffset()
        upload_datetime = upload_datetime.replace(tzinfo=None)
        if offset is not None:
            upload_datetime -= offset
        return (upload_datetime - datetime(1970, 1, 1)).total_seconds()
    except Exception as err:
        # Log instead of printing: library code must not write to stdout
        # (which is the data stream inside YT jobs).
        logging.error("failed in reading 'upload_time' for YT-table '%s': %s", table, err)
        return None


def convert_yt_to_py(yt_type_name, use_yson=False):
    """Return the Python type (or converter) for values of YT column type ``yt_type_name``.

    Small integer types map to ``int``; int64/uint32/uint64 map to ``long``
    (``int`` on Python 3). With ``use_yson``, uint64 maps to ``yson.YsonUint64``.
    "any" maps to ``json.loads``.

    Raises:
        KeyError: unsupported type name.
    """
    # TODO: type "utf8" is not supported yet
    yt2py = dict.fromkeys(("int8", "int16", "int32", "uint8", "uint16"), int)
    yt2py.update(dict.fromkeys(("int64", "uint32", "uint64"), long))
    yt2py.update({
        "double": float,
        "boolean": bool,
        "string": str,
        "any": json.loads,
    })
    if use_yson:
        yt2py['uint64'] = yson.YsonUint64
    return yt2py[yt_type_name]


def convert_py_to_yt(py_type):
    """Return the YT column type that can store values of Python type ``py_type``.

    The lookup is by ``py_type.__name__``; unknown names map to ``'any'``.
    """
    type_name = py_type.__name__
    if type_name in ('int', 'long'):
        return 'int64'
    return {
        'str': 'string',
        'bool': 'boolean',
        'float': 'double',
        'YsonUint64': 'uint64',
    }.get(type_name, 'any')


def get_schema(types, strict=False):
    """Build a YT table schema from a ``{column_name: type}`` mapping.

    Each type is either a YT type name string (e.g. ``"string"``), used as is,
    or a Python type (e.g. ``str``), converted with ``convert_py_to_yt``.
    Columns are sorted by name.

    Args:
        types: mapping of column name to YT type name or Python type.
        strict: value of the schema's ``strict`` attribute.

    Returns:
        ``yson.YsonList`` of ``{"name", "type"}`` column dicts.
    """
    field_list = []
    for fld in sorted(types):
        tp = types[fld]
        # six.string_types: unicode YT type names must also work on Python 2,
        # where isinstance(u"string", str) is False.
        tp_str = tp if isinstance(tp, six.string_types) else convert_py_to_yt(tp)
        field_list.append({"name": fld, "type": tp_str})

    schema = yson.YsonList(field_list)
    schema.attributes["strict"] = strict
    return schema


def columns_to_schema(columns):
    """Strip column descriptions down to schema keys.

    Each column keeps ``name`` and ``type`` (KeyError if either is missing) plus
    ``sort_order`` and ``required`` when present; other keys are dropped.
    Returns a plain list of dicts.
    """
    def _strip(column):
        stripped = {'name': column['name'], 'type': column['type']}
        for optional_key in ('sort_order', 'required'):
            if optional_key in column:
                stripped[optional_key] = column[optional_key]
        return stripped

    return [_strip(column) for column in columns]


def yt_time2ts(yt_time_str):
    """Convert a YT time string like ``"2020-01-01T12:00:00.123456Z"`` to integer Unix seconds.

    The time is interpreted as UTC; the fractional part is dropped.
    """
    parsed = datetime.strptime(yt_time_str, "%Y-%m-%dT%H:%M:%S.%fZ")
    return timegm(parsed.timetuple())


def get_mtime(path, yt_client=yt):
    """Return the ``modification_time`` of ``path`` as integer Unix seconds.

    Args:
        path: Cypress path.
        yt_client: YT client to query; defaults to the ``yt.wrapper`` module,
            which is what the function always used before this parameter existed.
    """
    modtime = yt_client.get_attribute(path, 'modification_time')
    return yt_time2ts(modtime)


def get_cdict_generation_params():
    """Return a new dict with the cdict generation parameters."""
    return dict(
        max_phrase_length=500,   # max phrase length (longer phrases are ignored)
        max_query_words=6,       # max number of words in a query
        max_subphrase_words=6,   # max number of words in generated subphrases
        max_generation_words=6,  # max words in a phrase subphrases are generated from (longer ones are truncated)
        min_phrase_frequency=5,  # min phrase frequency (rarer phrases are ignored)
        min_query_frequency=0,   # min phrase frequency on search (rarer phrases are ignored)
        min_time_generation_hour=24 * 5,  # how often to regenerate (hours, per the name)
        cdict_generations_yt_path='//home/broadmatching/cdict_generation',
        cdict_out_yt_path='//home/broadmatching/work/cdict',
    )


def create_dynamic_table(name, schema, tablet_cell_bundle, yt_client=yt, primary_medium='ssd_blobs'):
    """Create an empty dynamic table and mount it (``sync=True``).

    Modeled on ``create_dyntable`` by breqwas@.

    Args:
        name: Cypress path of the new table.
        schema: table schema.
        tablet_cell_bundle: tablet cell bundle for the table.
        yt_client: YT client (defaults to the ``yt.wrapper`` module).
        primary_medium: primary storage medium; defaults to the previously
            hard-coded ``'ssd_blobs'``.
    """
    attributes = {
        'dynamic': True,
        'schema': schema,
        'primary_medium': primary_medium,
        'tablet_cell_bundle': tablet_cell_bundle,
    }
    yt_client.create('table', path=name, attributes=attributes)
    yt_client.mount_table(name, sync=True)


class CacheMode(enum.Enum):
    """How ``DynTableCache.call`` uses the cache for reading or writing.

    NEVER  -- skip the cache step entirely;
    TRY    -- use the cache, but log and ignore cache errors;
    ALWAYS -- use the cache and propagate cache errors.
    """

    NEVER = 0
    TRY = 1
    ALWAYS = 2


class DynTableCache(object):
    """Cache of per-key results stored in a YT dynamic table.

    Args:
        table: path to the dynamic table; it must have the columns ``key_field``,
            ``value_fields`` and ``update_time``.
        key_field: key column of the dynamic table (the table may also have a
            farm_hash key column, which does not need to be specified).
        value_fields: columns that hold the cached result.
        check_fields: columns checked for nulls when applying ``ttl_null``
            (default: ``value_fields``).
        ttl: lifetime of a record, in seconds (None: no limit).
        ttl_null: lifetime of a record where any of ``check_fields`` is null
            (None: no separate limit).
        lookup_client: client for lookups, inserts and other operations on ``table``.
        map_reduce_client: client for static tables and MapReduce operations.
    """

    def __init__(
        self,
        table,
        key_field,
        value_fields,
        check_fields=None,
        ttl=None,
        ttl_null=None,
        lookup_client=yt,
        map_reduce_client=yt
    ):
        self.table = table
        self.key_field = key_field
        self.value_fields = value_fields
        if check_fields is None:
            check_fields = value_fields

        self.lookup_client = lookup_client
        self.map_reduce_client = map_reduce_client

        self.read_pack_size = 1000
        self.write_pack_size = 1000

        # make check regular function instead of method to use it in mapper
        def check_func(row):
            """Return True if a cache row is still fresh (``update_time`` is Unix seconds)."""
            check_time = time.time()
            update_time = row['update_time']
            if ttl is not None:
                if check_time - update_time > ttl:
                    return False
            if ttl_null is not None:
                if any(row[f] is None for f in check_fields) and check_time - update_time > ttl_null:
                    return False
            return True

        self.check_func = check_func

    def read(self, keys):
        """Look up ``keys`` in packs of ``read_pack_size``.

        Returns ``{key: {value_field: value, ...}}`` for fresh records only;
        stale records (per ``check_func``) are treated as missing.
        """
        result = {}
        for keys_pack in gen_packs(keys, pack_size=self.read_pack_size):
            logging.info('cache read %d items', len(keys_pack))
            keys_rows = [{self.key_field: key} for key in keys_pack]
            for row in self.lookup_client.lookup_rows(self.table, keys_rows):
                if self.check_func(row):
                    result[row[self.key_field]] = {f: row[f] for f in self.value_fields}
        logging.info('cache read done!')
        return result

    def write(self, data):
        """Insert ``{key: {"v1": val1, "v2": val2}}`` into the cache, stamping ``update_time``.

        Rows are inserted in packs of ``write_pack_size``.
        """
        rows = []
        for k, v in data.items():
            row = v.copy()
            row[self.key_field] = k
            row['update_time'] = int(time.time())
            rows.append(row)

        for rows_pack in gen_packs(rows, pack_size=self.write_pack_size):
            logging.info('cache write %d items', len(rows_pack))
            self.lookup_client.insert_rows(self.table, rows_pack)

        logging.info('cache write done!')

    def call(self, function, args, get_key=lambda arg: arg, read_cache=CacheMode.ALWAYS, write_cache=CacheMode.ALWAYS):
        """Call ``function`` only for args whose keys are not cached; merge in cached results.

        ``function`` maps ``[arg1, arg2, ...]`` to ``{key1: {..values..}, ...}``;
        ``key = get_key(arg)`` is the cache key (default: the arg itself).
        When reading the cache, ``args`` is iterated twice, so it should be a
        sequence rather than a one-shot iterator.

        Args:
            read_cache / write_cache: CacheMode — NEVER skips the step, TRY logs
                and ignores cache errors, ALWAYS propagates them.

        Returns:
            The dict returned by ``function``, updated with the cached results.

        Raises:
            ValueError: read_cache or write_cache is not a CacheMode.
        """
        cache = {}
        todo = args

        if not (isinstance(read_cache, CacheMode) and isinstance(write_cache, CacheMode)):
            logging.error('read_cache/write_cache must be CacheMode instance')
            raise ValueError('read_cache/write_cache must be CacheMode instance')

        if read_cache != CacheMode.NEVER:
            keys = set(get_key(x) for x in args)
            try:
                cache = self.read(keys)
            except Exception as e:  # cache may be unavailable
                if read_cache == CacheMode.ALWAYS:
                    raise  # bare raise keeps the traceback (also on Python 2)
                elif read_cache == CacheMode.TRY:
                    logging.error('Exception at cache read: %s', e)
                    cache = {}

            todo = [x for x in todo if get_key(x) not in cache]

        data = function(todo)

        if write_cache != CacheMode.NEVER:
            try:
                self.write(data)
            except Exception as e:
                if write_cache == CacheMode.ALWAYS:
                    raise  # bare raise keeps the traceback (also on Python 2)
                elif write_cache == CacheMode.TRY:
                    logging.error('Exception at cache write: %s', e)

        data.update(cache)
        return data

    def get_cache_columns(self):
        """Return the cache table's columns, keeping only ``name``, ``type`` and ``required``."""
        schema = self.lookup_client.get_attribute(self.table, 'schema')
        return [{k: col.get(k) for k in ['name', 'type', 'required']} for col in schema]

    def get_key_type(self):
        """Return the Python type of the key column (via ``convert_yt_to_py`` with ``use_yson=True``)."""
        col_yt_type = {col['name']: col['type'] for col in self.get_cache_columns()}
        return convert_yt_to_py(col_yt_type[self.key_field], use_yson=True)

    def get_extended_schema(self, input_table, add_key=False, add_values=False):
        """Return ``input_table``'s schema without sort order, plus missing cache columns.

        The key column (``add_key``) and/or value columns (``add_values``) are
        copied from the cache schema unless ``input_table`` already has them.
        """
        input_schema = self.map_reduce_client.get_attribute(input_table, 'schema')
        for col in input_schema:
            col.pop('sort_order', None)  # we do not save sort order in hit/miss tables
        add_col_names = set()
        if add_key:
            add_col_names.add(self.key_field)
        if add_values:
            add_col_names.update(self.value_fields)
        add_col_names -= set(col['name'] for col in input_schema)
        add_columns = [col for col in self.get_cache_columns() if col['name'] in add_col_names]
        input_schema += add_columns
        return input_schema

    def get_hit_schema(self, input_table):
        """Schema of the cache-hit table: input columns + key + value columns."""
        return self.get_extended_schema(input_table, add_key=True, add_values=True)

    def get_miss_schema(self, input_table):
        """Schema of the cache-miss table: input columns + key."""
        return self.get_extended_schema(input_table, add_key=True, add_values=False)

    def read_mr(self, input_table, hit_table, miss_table, get_arg=None, get_key=lambda arg: arg):
        """Use the cache as a static table and split ``input_table`` with one MapReduce.

        Rows whose key has a fresh cache record go to ``hit_table`` with the
        cached values added; the others go to ``miss_table``. The key of an input
        row is ``row[key_field]`` if present, otherwise
        ``key_type(get_key(get_arg(row)))``; by default ``get_arg`` takes
        ``row[key_field]``.
        """
        check_func = self.check_func
        ti_fld = '_table_index'
        key_fld = self.key_field
        val_flds = self.value_fields
        if get_arg is None:
            get_arg = lambda row: row[key_fld]
        key_type = self.get_key_type()

        def preprocess(row):
            # move the table index into a regular field so the reducer sees it
            ti = row[ti_fld] = row.pop('@table_index')
            if ti == 0:  # cache
                if check_func(row):
                    yield row
            else:
                if key_fld not in row:
                    row[key_fld] = key_type(get_key(get_arg(row)))
                yield row

        def join_with_cache(group_key, rows):
            # rows are sorted by (key, table index), so a cache row comes first
            cache = None
            for row in rows:
                if row.pop(ti_fld) == 0:
                    cache = row
                    continue
                if cache is None:
                    # cache miss; input row (+key)
                    row['@table_index'] = 1
                    yield row
                else:
                    # cache hit; input row + values (+key)
                    row['@table_index'] = 0
                    for f in val_flds:
                        row[f] = cache[f]
                    yield row

        self.map_reduce_client.create('table', hit_table, ignore_existing=True)
        self.map_reduce_client.alter_table(hit_table, schema=self.get_hit_schema(input_table))

        self.map_reduce_client.create('table', miss_table, ignore_existing=True)
        self.map_reduce_client.alter_table(miss_table, schema=self.get_miss_schema(input_table))

        self.map_reduce_client.run_map_reduce(
            preprocess,
            join_with_cache,
            [self.table, input_table],
            [hit_table, miss_table],
            reduce_by=[key_fld],
            sort_by=[key_fld, ti_fld],
            map_input_format=yt.YsonFormat(control_attributes_mode='row_fields'),
            reduce_output_format=yt.YsonFormat(control_attributes_mode='row_fields'),
        )

    def ensure_mounted(self):
        """Workaround for failures during freeze/unfreeze: unfreeze a frozen table, mount an unmounted one."""
        tablet_state = self.lookup_client.get_attribute(self.table, 'tablet_state')
        if tablet_state == 'frozen':
            self.lookup_client.unfreeze_table(self.table, sync=True)
        elif tablet_state == 'unmounted':
            self.lookup_client.mount_table(self.table, sync=True)

    def call_for_table(self, function, input_table, output_table, call_pack_size, get_arg=None, get_key=lambda arg: arg, stats_hook=None):
        """Enrich ``input_table`` with cached values, calling ``function`` for cache misses.

        First the MapReduce cache (``read_mr``) splits the input into hits and
        misses; then ``function`` is called on the misses in packs of
        ``call_pack_size``.

        Args:
            get_arg: row -> arg (default: ``row[key_field]``).
            get_key: arg -> key (default: the arg itself).
            output_table: if not None, hit rows and miss rows enriched with the
                function results are merged into it.
            stats_hook: called with the stats Counter after each pack.

        Returns:
            Counter with cache_hit, cache_miss, function_input, function_output.

        NOTE(review): function results are not written back to ``self.table``
        here, and miss rows whose key ``function`` did not return are absent
        from ``output_table``.
        """
        key_fld = self.key_field
        stats = Counter()
        write_output = output_table is not None
        if get_arg is None:
            get_arg = lambda row: row[key_fld]
        with self.map_reduce_client.TempTable() as hit_table,\
                self.map_reduce_client.TempTable() as miss_table,\
                self.map_reduce_client.TempTable() as func_data_table:

            self.read_mr(input_table, hit_table, miss_table, get_arg=get_arg, get_key=get_key)
            stats['cache_hit'] = self.map_reduce_client.row_count(hit_table)
            stats['cache_miss'] = total_miss = self.map_reduce_client.row_count(miss_table)
            logging.info('cache stats: %s', stats)

            self.map_reduce_client.alter_table(func_data_table, schema=self.get_hit_schema(input_table))
            for pos in range(0, total_miss, call_pack_size):
                table_view = self.map_reduce_client.TablePath(miss_table, start_index=pos, end_index=pos + call_pack_size)
                key2rows = {}  # keep the input rows for each key
                args = []
                logging.info('call_for_table: read pack ...')
                for row in self.map_reduce_client.read_table(table_view, unordered=True, enable_read_parallel=True):
                    arg = get_arg(row)
                    key = get_key(arg)
                    args.append(arg)
                    key2rows.setdefault(key, []).append(row)
                stats['function_input'] += len(args)
                data_rows = []
                logging.info('call_for_table: call function ...')
                data = function(args)
                stats['function_output'] += len(data.keys())
                if stats_hook is not None:
                    stats_hook(stats)
                if write_output:
                    for k, v in data.items():
                        for row in key2rows[k]:
                            row[key_fld] = k
                            for f in self.value_fields:
                                row[f] = v[f]
                            data_rows.append(row)
                    self.map_reduce_client.write_table(self.map_reduce_client.TablePath(func_data_table, append=True), data_rows)
                logging.info('processed %d of %d', min(pos + call_pack_size, total_miss), total_miss)
            if write_output:
                self.map_reduce_client.run_merge([hit_table, func_data_table], output_table)

        return stats
