# -*- encoding: utf-8 -*-
"""
Сборник небольших функций для работы с YT, специфичных для Директа
"""

from __future__ import absolute_import

import dateutil.parser
import datetime
import logging
import os
import sys
from itertools import izip

from MySQLdb import FIELD_TYPE

from yt import yson
import yt.wrapper as yt
from yt.wrapper.string_iter_io import StringIterIO
from yt.wrapper.format import DsvFormat, YsonFormat

import settings
from yandex.utils import tmp_attr
import ytutils_mod_filter

from direct.tracing import Trace

# Table format used by save_cursor_to_yt when the caller does not pass one.
DEFAULT_FORMAT = yt.DsvFormat()
# Codecs applied by check_create_yt_tbl(compressed=True) and used as the
# default targets of compress_yt_table.
DEFAULT_COMPRESSION_CODEC = 'brotli_8'
DEFAULT_ERASURE_CODEC = 'lrc_12_2_2'
# 8 GiB; default "desired_chunk_size" for the merge run by compress_yt_table.
DEFAULT_DESIRED_CHUNK_SIZE = 8 * 1024 ** 3
# Alternative codec spellings mapped to their canonical names (see norm_codec).
CODEC_ALIASES = {"brotli8": "brotli_8"}
# Read at import time: a missing YT_PREFIX environment variable raises KeyError.
YT_PREFIX = os.environ['YT_PREFIX']
# NOTE(review): plain concatenation, so YT_PREFIX is presumably expected to
# end with '/' - confirm.
YT_TMP_PREFIX = YT_PREFIX + 'tmp/'


def yt_schema(strict, columns):
    """Build a YSON list describing a table schema.

    strict  - value stored in the "strict" attribute of the resulting list
    columns - iterable of column specs; each one is either a ready dict
              (appended as is) or a 2-element list/tuple (name, type), which
              is turned into {'name': ..., 'type': ...}

    Specs of any other shape are silently skipped.
    """
    result = yson.YsonList()
    result.attributes["strict"] = strict
    for column in columns:
        if isinstance(column, dict):
            result.append(column)
            continue
        is_pair = isinstance(column, (list, tuple)) and len(column) == 2
        if is_pair:
            col_name, col_type = column
            result.append({'name': col_name, 'type': col_type})
    return result


def check_create_yt_tbl(yt_tbl, cleanup=False, sort_by=None, channels=None, compressed=False, schema=None, optimize_for=None, expiration_time=None):
    """Make sure a YT table exists, optionally compressing, erasing and sorting it.

    yt_tbl          - table path
    cleanup         - erase the table contents
    sort_by         - columns to sort by; the sort runs only if the table is
                      not already sorted
    channels        - accepted but not used by this function
    compressed      - set the default compression and erasure codecs
    schema          - "schema" attribute for a newly created table
    optimize_for    - 'lookup' or 'scan'; applied only to a newly created table
    expiration_time - "expiration_time" attribute for a newly created table

    Raises Exception if the table has to be created and optimize_for is
    neither None, 'lookup' nor 'scan'.
    """
    if not yt.exists(yt_tbl):
        create_attrs = {}
        if schema is not None:
            create_attrs['schema'] = schema
        if optimize_for is not None:
            if optimize_for not in ('lookup', 'scan'):
                raise Exception("Incorrect optimize_for")
            create_attrs['optimize_for'] = optimize_for
        if expiration_time is not None:
            create_attrs['expiration_time'] = expiration_time
        yt.create("table", yt_tbl, recursive=True, attributes=create_attrs)

    if compressed:
        codec_attrs = (
            ("compression_codec", DEFAULT_COMPRESSION_CODEC),
            ("erasure_codec", DEFAULT_ERASURE_CODEC),
        )
        for attr_name, attr_value in codec_attrs:
            yt.set(yt_tbl + ("/@" + attr_name), attr_value)

    if cleanup:
        yt.run_erase(yt_tbl)

    # The sortedness check is issued even when sort_by is empty.
    already_sorted = yt.is_sorted(yt_tbl)
    if not already_sorted and sort_by:
        yt.run_sort(yt_tbl, sort_by=sort_by)


def norm_codec(codec):
    """Return the canonical name for a codec alias, or the codec itself if it has no alias."""
    try:
        return CODEC_ALIASES[codec]
    except KeyError:
        return codec


def compress_yt_table(yt_tbl, force=False,
                      sort_by=None, erasure_codec=DEFAULT_ERASURE_CODEC,
                      compression_codec=DEFAULT_COMPRESSION_CODEC,
                      desired_chunk_size=None,
                      data_size_per_job=None,
                      optimize_for=None):
    """
    Compress a table, optionally sorting it first for better efficiency.
    https://wiki.yandex-team.ru/yt/userdoc/erasure

    yt_tbl             - table path
    force              - recompress even if the table attributes already match
    sort_by            - columns to sort by before merging; None means "keep as is"
    erasure_codec      - target value of the "erasure_codec" attribute
    compression_codec  - target value of the "compression_codec" attribute
    desired_chunk_size - chunk size for the table writer
                         (defaults to DEFAULT_DESIRED_CHUNK_SIZE)
    data_size_per_job  - amount of data per merge job; by default derived from
                         desired_chunk_size and the table's compression ratio
    optimize_for       - 'lookup', 'scan' or None (leave the attribute untouched)

    Raises Exception if optimize_for has any other value.
    """

    if optimize_for is not None and optimize_for not in ('lookup', 'scan'):
        raise Exception("Incorrect optimize_for")

    attrs = yt.get(yt_tbl + "/@")

    # "sorted_by" is only read when the table reports itself as sorted.
    is_well_sorted = sort_by is None or (attrs["sorted"] and sort_by == attrs["sorted_by"])

    already_compressed = (
        norm_codec(attrs["compression_codec"]) == norm_codec(compression_codec)
        and attrs["erasure_codec"] == erasure_codec
        and (optimize_for is None or attrs["optimize_for"] == optimize_for)
        and is_well_sorted
    )
    if not force and already_compressed:
        logging.debug("table {} already compressed".format(yt_tbl))
        return

    if desired_chunk_size is None:
        desired_chunk_size = DEFAULT_DESIRED_CHUNK_SIZE

    if data_size_per_job is None:
        if attrs["compression_ratio"] > 0:
            # Size the job input so that its compressed output is about one
            # desired chunk, but never below 256 MiB.
            data_size_per_job = max(256 * 1024 * 1024, int(desired_chunk_size / attrs["compression_ratio"]))
        else:
            data_size_per_job = desired_chunk_size

    # All attribute changes and operations happen in one transaction.
    with yt.Transaction(ping=True):
        yt.set_attribute(yt_tbl, "compression_codec", compression_codec)
        yt.set_attribute(yt_tbl, "erasure_codec", erasure_codec)
        if optimize_for is not None:
            yt.set_attribute(yt_tbl, "optimize_for", optimize_for)

        if not is_well_sorted:
            logging.warning("Sort %s", yt_tbl)
            yt.run_sort(yt_tbl, sort_by=sort_by)

        logging.warning("Compress %s", yt_tbl)
        # Merge the table into itself so the chunks are rewritten with the
        # attributes set above.
        yt.run_merge(yt_tbl, yt_tbl,
                     mode="sorted" if attrs['sorted'] or sort_by else "unordered",
                     spec={
                         "combine_chunks": True,
                         "force_transform": True,
                         "data_size_per_job": data_size_per_job,
                         "job_io": {
                             "table_writer": {
                                 "desired_chunk_size": desired_chunk_size
                             }
                         }
                     }
        )


# Hand-rolled escaping: unfortunately it is considerably faster than the built-in serialization.
_escape_dsv_key_cache = {}
def escape_dsv_key(s):
    """Escape a DSV key: backslash, newline, carriage return, tab, NUL and '='.

    Results are memoized in _escape_dsv_key_cache.
    """
    escaped = _escape_dsv_key_cache.get(s)
    if escaped is None:
        # The backslash must be handled first so that the backslashes added
        # by the later replacements are not escaped again.
        escaped = (
            s.replace("\\", "\\\\")
            .replace("\n", "\\n")
            .replace("\r", "\\r")
            .replace("\t", "\\t")
            .replace("\0", "\\0")
            .replace("=", "\\=")
        )
        _escape_dsv_key_cache[s] = escaped
    return escaped


def escape_dsv_val(s):
    """Escape a DSV value: backslash, newline, carriage return, tab and NUL ('=' is left as is)."""
    # The backslash goes first so that the backslashes introduced by the
    # following replacements are not doubled.
    return (
        s.replace("\\", "\\\\")
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
        .replace("\0", "\\0")
    )


def cursor_lines_generator(cursor, fmt, limit=None, preprocess=None, fetch_limit=1000):
    """Turn a DB cursor into a generator of serialized YT records.

    cursor      - DB cursor with an executed query; it is closed once a fetch
                  returns no rows
    fmt         - one of the YT formats
    limit       - total number of rows this generator reads; once it is
                  reached the cursor stays open, so another generator can be
                  created on the same cursor
    preprocess  - optional callable invoked with every row dict (it may modify
                  the dict in place) before serialization
    fetch_limit - how many rows are fetched per round trip

    For YsonFormat one serialized list fragment is yielded per fetched batch;
    for any other format one item is yielded per row. Fields whose value is
    None are dropped from the output.
    """
    cnt = 0
    cols = [d[0] for d in cursor.description]
    while True:
        if limit and cnt >= limit:
            break
        if limit and fetch_limit > limit - cnt:
            # Do not fetch past the requested limit.
            fetch_limit = limit - cnt
        rows = [dict(izip(cols, r)) for r in cursor.fetchmany(fetch_limit)]
        cnt += len(rows)
        if not rows:
            cursor.close()
            break
        if preprocess is not None:
            for row in rows:
                preprocess(row)
        if isinstance(fmt, YsonFormat):
            yson_format = fmt.attributes.get('format', 'binary')
            # Null fields are removed, same as in the other branches.
            for row in rows:
                # Iterate over a snapshot: the dict is modified in the loop.
                for key, val in list(row.items()):
                    if val is None:
                        del row[key]
            yield yson.dumps(rows, yson_format, 'list_fragment')
        elif isinstance(fmt, DsvFormat):
            # fast dsv serialization
            for row in rows:
                pairs = []
                for key, val in row.items():
                    _key = escape_dsv_key(key)
                    if val is None:
                        # dsv does not support null, such fields are skipped
                        continue
                    elif isinstance(val, datetime.datetime):
                        _val = val.strftime("%Y%m%d%H%M%S")
                    elif isinstance(val, (int, float)):
                        _val = str(val)
                    elif isinstance(val, unicode):
                        _val = escape_dsv_val(val.encode("utf8"))
                    else:
                        _val = escape_dsv_val(str(val))
                    pairs.append(_key + "=" + _val)
                # A row with no non-null fields produces no output line.
                if pairs:
                    yield "\t".join(pairs) + "\n"
        else:
            for row in rows:
                # Iterate over a snapshot: the dict is modified in the loop.
                for key, val in list(row.items()):
                    if val is None:
                        del row[key]
                    elif isinstance(val, datetime.datetime):
                        row[key] = val.strftime("%Y%m%d%H%M%S")
                    elif isinstance(val, unicode):
                        row[key] = val.encode("utf8")
                yield fmt.dumps_row(row)


def save_cursor_to_yt(yt_tbl, cursor, fmt=None, cleanup=False, channels=None, preprocess=None, chunk_size=100000, schema=None, optimize_for=None):
    """Upload the contents of a DB cursor into a YT table.

    The table is prepared with check_create_yt_tbl, then rows are appended in
    portions of at most chunk_size rows for as long as cursor.connection is
    truthy.
    NOTE(review): the loop presumably ends because cursor_lines_generator
    closes the exhausted cursor, which resets cursor.connection - confirm
    against the DB driver.

    fmt        - YT format; DEFAULT_FORMAT when None
    preprocess - passed through to cursor_lines_generator
    cleanup, channels, schema, optimize_for - passed to check_create_yt_tbl
    """
    check_create_yt_tbl(yt_tbl, cleanup=cleanup, channels=channels, schema=schema, optimize_for=optimize_for)
    out_format = DEFAULT_FORMAT if fmt is None else fmt
    # YSON batches are passed as a plain generator; other formats are wrapped
    # into a file-like object.
    wrap_into_stream = not isinstance(out_format, YsonFormat)
    # save by chunks
    while cursor.connection:
        logging.info(" append data to %s" % yt_tbl)
        payload = cursor_lines_generator(cursor, out_format, limit=chunk_size, preprocess=preprocess)
        if wrap_into_stream:
            payload = StringIterIO(payload)
        with Trace.current().profile("yt:write", tags=unicode(os.environ.get('YT_PROXY', ''))):
            destination = yt.TablePath(yt_tbl, append=True)
            yt.write_table(destination, payload, out_format, raw=True)


def close_yt_connections():
    """Drop all pooled keep-alive HTTP connections of the YT client.

    This includes connections inherited from the parent process, which makes
    it useful after a fork. The call is best effort: any failure is logged at
    debug level and otherwise ignored.
    """
    try:
        yt.http.get_session().get_adapter('http://').poolmanager.pools.clear()
    except Exception:
        # Deliberately not re-raised: failing to clear the pool must not
        # break the caller.
        logging.debug("Failed to clear YT connection pools", exc_info=True)


def tune_db_cursor(cursor):
    """
    Adjust the converters of the cursor's connection for uploading to YT.

    yson cannot handle datetime and decimal, so their conversion is switched
    off; the converters for enum/set and float/double columns are removed
    as well.
    """
    converters = cursor.connection.converter
    unconverted_types = (
        FIELD_TYPE.DATE, FIELD_TYPE.DATETIME, FIELD_TYPE.TIMESTAMP,
        FIELD_TYPE.ENUM, FIELD_TYPE.SET,
        FIELD_TYPE.FLOAT, FIELD_TYPE.DOUBLE,
        FIELD_TYPE.DECIMAL, FIELD_TYPE.NEWDECIMAL,
    )
    for field_type in unconverted_types:
        # Absent converters are simply skipped.
        converters.pop(field_type, None)
    # yson has no unsigned int64, only the signed one.
    # Convert to int - this will fail on numbers >= 2**63;
    # to upload 64-bit hashes they have to be converted to str.
    for integer_type in (FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG):
        converters[integer_type] = int


def run_simple_operation(op, *args, **kwargs):
    """
    Run a YT operation while forbidding the packing of __main__ and of all
    modules except yt*.
    An analogue of @yt.simple that should survive yandex-yt-python updates on
    the clusters more reliably.

    op is called as op(*args, **kwargs) and its result is returned. The
    previous pickling module filter is restored even if op raises.
    """
    main_module = sys.modules['__main__']
    with tmp_attr(main_module, '__file__', '/dev/null'):
        saved_filter = yt.config['pickling']['module_filter']
        yt.config['pickling']['module_filter'] = ytutils_mod_filter.mod_filter
        try:
            result = op(*args, **kwargs)
        finally:
            yt.config['pickling']['module_filter'] = saved_filter

    return result


def cleanup_yt_dir(path, border_time, recurse=True, force=False):
    """Find (and with force=True remove) stale nodes under a YT directory.

    path        - directory to scan
    border_time - naive datetime; tables, files and links last accessed before
                  it, and empty directories last modified before it, are
                  treated as old
                  NOTE(review): node timestamps have their tzinfo stripped
                  before the comparison, so border_time is presumably expected
                  in the same zone as the YT timestamps - confirm
    recurse     - descend into non-empty subdirectories
    force       - actually remove old nodes; otherwise they are only logged

    Nodes carrying the 'ignore_direct_auto_cleanup' attribute and nodes
    without 'access_time' are skipped.

    Raises UserWarning(msg, original_exception) if processing a node fails.
    """
    list_info = yt.list(path, attributes=['type', 'name', 'access_time', 'modification_time', 'count', 'ignore_direct_auto_cleanup'])
    for node in list_info:
        node_path = path + '/' + node
        try:
            attrs = node.attributes
            if 'ignore_direct_auto_cleanup' in attrs:
                logging.warning("skip node %s, because it protected by attribute", node_path)
                continue
            if 'access_time' not in attrs:
                # possible no access for node
                logging.warning("skip node %s, because there is no access_time attr", node_path)
                continue
            access_time = dateutil.parser.parse(attrs['access_time']).replace(tzinfo=None)
            node_type = attrs['type']
            if node_type == 'map_node':
                if attrs['count'] == 0:
                    modification_time = dateutil.parser.parse(attrs['modification_time']).replace(tzinfo=None)
                    if modification_time < border_time:
                        if force:
                            logging.warning("remove empty node %s", node_path)
                            yt.remove(node_path)
                        else:
                            logging.warning("old empty node %s", node_path)
                elif recurse:
                    cleanup_yt_dir(node_path, border_time, recurse, force)
            elif node_type in {'table', 'file', 'link'} and access_time < border_time:
                if force:
                    logging.warning("remove %s %s", node_type, node_path)
                    yt.remove(node_path)
                else:
                    logging.warning("old %s %s", node_type, node_path)
        except UserWarning:
            # Already wrapped by a nested call - pass it up unchanged.
            raise
        except Exception as e:
            msg = "Failed to process node " + node_path
            logging.error(msg, exc_info=True)
            raise UserWarning(msg, e)
