import yt.yson as yson

import os
import time
import numpy as np
from functools import wraps

import logging
# Module-level logger; with_no_exception() reports swallowed exceptions through it.
logger = logging.getLogger(__name__)


def with_no_exception(func):
    """Return a wrapper around *func* that never lets an ``Exception`` escape.

    When *func* raises an ``Exception`` subclass, the traceback is logged via
    ``logger.exception`` and the wrapper returns ``None``.  Otherwise the
    wrapper returns whatever *func* returned.  ``BaseException`` subclasses
    that do not derive from ``Exception`` still propagate.
    """
    @wraps(func)
    def guarded(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            logger.exception("{} is failed".format(func.__name__))
        # Only reached when func raised: best-effort callers get None.
        return None
    return guarded


class with_preparation(object):
    """Decorator factory that optionally calls ``prepare_table()`` before a method.

    The decorated method gains an optional keyword argument named by
    ``prepare_key`` (``"with_preparing"``).  That argument is stripped before
    the wrapped method is called.  If it is truthy, or if it was omitted and
    *default* is truthy, ``prepare_table()`` is called on the first positional
    argument (the instance) before the method runs.
    """
    prepare_key = "with_preparing"

    def __init__(self, default):
        # Preparation flag used when the caller does not pass ``with_preparing``.
        self.default = default

    def __call__(self, func):
        @wraps(func)
        def wrapper(other, *args, **kwargs):
            # pop() both reads the flag (falling back to the default) and
            # removes it, so the wrapped method never sees it.
            needs_preparation = kwargs.pop(self.prepare_key, self.default)
            if needs_preparation:
                other.prepare_table()
            return func(other, *args, **kwargs)
        return wrapper


class Tablet(object):
    """Constants and helpers for inspecting YT tablet state descriptions."""

    # Key under which each tablet description stores its state string.
    STATE_FIELD = 'state'

    class TabletStates(object):
        """State values this module compares against."""
        MOUNTED = 'mounted'
        FROZEN = 'frozen'

    @staticmethod
    def get_finished_state_checker():
        """Return a predicate that is true when a tablet is not mid-transition.

        A state whose name ends in ``'ing'`` is treated as still in progress;
        any other state counts as finished.
        """
        def is_finished(tablet):
            current = tablet[Tablet.STATE_FIELD]
            in_transition = current.endswith('ing')
            return not in_transition
        return is_finished


class AbstractDynamicStorage(object):
    """Helper around a single YT dynamic table.

    Bundles a YT client, a table path and a schema.  Provides methods to
    (re)create, reshard and mount the table, and to insert, select and delete
    rows.  Row operations decorated with ``with_preparation(True)`` call
    :meth:`prepare_table` first unless the caller passes
    ``with_preparing=False``.
    """

    @property
    def path(self):
        """Path of the table in YT."""
        return self._path

    @property
    def schema(self):
        """Schema used when the table is (re)created by :meth:`create`."""
        return self._schema

    def __init__(self, yt, path=None, schema=None, experiment_mode=False, **kwargs):
        """
        :param yt: YT client used for every call (presumably a ``yt.wrapper``
            client -- TODO confirm).
        :param path: path of the table.
        :param schema: schema for the table when it is created.
        :param experiment_mode: when true, ``insert_rows``, ``reshard_table``,
            ``prepare_table`` and ``create`` are replaced on this instance by
            versions that log and swallow any exception (returning ``None``).
        :param kwargs: accepted and ignored.
        """
        self.yt = yt
        self._path = path
        self._schema = schema
        if experiment_mode:
            for func in [self.insert_rows, self.reshard_table, self.prepare_table, self.create]:
                # Instance attributes shadow the class methods for this object only.
                self.__dict__[func.__name__] = with_no_exception(func)

    def create(self, n_tablets=15, optimize_for=None, in_memory_mode=None, upstream_replica_id=None):
        """(Re)create the table as a dynamic table and mount it.

        Any existing node at :attr:`path` is removed first, so existing data
        is lost.  If the ``YT_USE_SINGLE_TABLET`` environment variable is set
        to a non-empty value, *n_tablets* is forced to 1.

        :param n_tablets: number of tablets passed to :meth:`reshard_table`.
        :param optimize_for: if given, set as the ``optimize_for`` attribute.
        :param in_memory_mode: if given, set as the ``in_memory_mode`` attribute.
        :param upstream_replica_id: if not None, stored as the
            ``upstream_replica_id`` attribute at creation time.
        """
        if os.environ.get("YT_USE_SINGLE_TABLET", ""):
            n_tablets = 1

        path = self.path
        attributes = dict(
            schema=self.schema,
            dynamic=True
        )
        if upstream_replica_id is not None:
            attributes["upstream_replica_id"] = upstream_replica_id

        self.yt.remove(path, force=True)
        self.yt.create(
            'table',
            path,
            attributes=attributes,
            recursive=True,
            ignore_existing=True,
        )
        self.yt.alter_table(path, dynamic=True)
        # Resharding and the attribute updates below are done on an unmounted table.
        if self.tablet_state == Tablet.TabletStates.MOUNTED:
            self.yt.unmount_table(path, sync=True)
        self.reshard_table(n_tablets)
        if optimize_for:
            self.yt.set_attribute(self.path, 'optimize_for', optimize_for)
        if in_memory_mode:
            self.yt.set_attribute(self.path, 'in_memory_mode', in_memory_mode)
        self.mount_table(sleep_time=1, timeout=10)

    def reshard_table(self, n_tablets):
        """Reshard the table into *n_tablets* tablets with evenly spaced pivots.

        Pivot ``i`` (for ``1 <= i < n_tablets``) is the single-column key
        ``YsonUint64((2**64 - 2) // n_tablets * i)``.  The first tablet starts
        at the empty key.  This presumably assumes the first key column is a
        uint64 (e.g. a hash) -- TODO confirm.

        The pivots are computed with exact Python integers.  The previous
        ``np.uint(-1)`` expression raises OverflowError on NumPy 2, and under
        Python 3's true division it produced float64 pivots with lost precision.
        """
        key_space = 2 ** 64 - 2
        pivot_keys = [[]] + [
            [yson.YsonUint64(key_space // n_tablets * i)] for i in range(1, n_tablets)]
        self.yt.reshard_table(self.path, pivot_keys=pivot_keys)

    def mount_table(self, sleep_time=0.2, timeout=4., **kwargs):
        """Bring the table into the mounted state.

        Returns ``True`` early if the table and all tablets already report
        ``'mounted'``.  Otherwise it waits for in-progress transitions to
        finish and unfreezes the table if any tablet is frozen.  If the table
        is still not mounted, it mounts it and waits for ``'mounted'``.

        :param sleep_time: polling interval for every wait, in seconds.
        :param timeout: upper bound on accumulated sleep for each wait, in seconds.
        :param kwargs: accepted and ignored (kept for backward compatibility).
        :return: ``True`` if mounted; after an explicit mount, the result of
            the final wait (``False`` if it timed out).
        """
        def wait(**checker_kwargs):
            # Apply the caller's polling settings to every wait below.
            return self._wait_state(sleep_time=sleep_time, timeout=timeout, **checker_kwargs)

        finished = Tablet.get_finished_state_checker()
        # The checker re-reads tablet states on each call, so it can be reused.
        is_mounted = self._get_state_checker(Tablet.TabletStates.MOUNTED)

        if is_mounted():
            return True
        wait(tablet_checker=finished)
        if self._get_state_checker(Tablet.TabletStates.FROZEN)(any):
            wait(tablet_checker=finished)
            self.yt.unfreeze_table(self.path, sync=True)
            wait(tablet_checker=lambda tablet: tablet[Tablet.STATE_FIELD] != Tablet.TabletStates.FROZEN)
            wait(tablet_checker=finished)
        if is_mounted():
            return True
        self.yt.mount_table(self.path, sync=True)

        return wait(state=Tablet.TabletStates.MOUNTED)

    @property
    def tablet_states(self):
        """Per-tablet descriptions from ``@tablets`` plus one pseudo-entry
        holding the table-level ``@tablet_state``, so state checkers also
        consider the table as a whole."""
        return list(self.yt.get('{}/@tablets'.format(self.path))) + [{Tablet.STATE_FIELD: self.tablet_state}]

    @property
    def tablet_state(self):
        """Table-level ``@tablet_state`` as a string."""
        return str(self.yt.get('{}/@tablet_state'.format(self.path)))

    def _get_state_checker(self, state=None, tablet_checker=None):
        """Build a callable that checks the current tablet states.

        The per-tablet predicate is *tablet_checker*, or ``state == state`` if
        none is given.  The returned callable takes an optional aggregate
        (default ``all``).  It re-reads :attr:`tablet_states` on every call.
        """
        tablet_checker = tablet_checker or (lambda tablet: tablet[Tablet.STATE_FIELD] == state)

        def checker(aggr_func=None):
            aggr_func = aggr_func or all
            return aggr_func(map(tablet_checker, self.tablet_states))

        return checker

    def _wait_state(self, state=None, tablet_checker=None, sleep_time=0.2, timeout=4.):
        """Poll until every tablet satisfies the check or *timeout* elapses.

        Only the time spent sleeping counts towards *timeout*; time spent in
        YT calls is not counted.

        :return: the result of one final check after polling stops.
        """
        is_ready = self._get_state_checker(state=state, tablet_checker=tablet_checker)
        sleep_time_sum = 0

        while not is_ready() and sleep_time_sum < timeout:
            sleep_time_sum += sleep_time
            time.sleep(sleep_time)

        return is_ready()

    @with_preparation(True)
    def insert_rows(self, rows, **kwargs):
        """Insert *rows* into the table; extra kwargs go to ``yt.insert_rows``."""
        self.yt.insert_rows(self.path, rows, **kwargs)

    @with_preparation(True)
    def select_rows(self, what=None, where=None, **kwargs):
        """Run ``SELECT <what> FROM [path] [WHERE <where>]`` and return the result.

        :param what: projection expression (default ``*``).
        :param where: optional filter expression appended verbatim.
        :param kwargs: forwarded to ``yt.select_rows`` (previously they were
            silently dropped), consistent with :meth:`insert_rows`.
        """
        query = "{} FROM [{}]".format(what or "*", self.path)
        if where is not None:
            query += " WHERE {}".format(where)
        return self.yt.select_rows(query, **kwargs)

    @with_preparation(True)
    def delete_rows(self, rows):
        """Delete *rows* from the table and return the client's result."""
        return self.yt.delete_rows(self.path, rows)

    def prepare_table(self, sleep_time=None, **kwargs):
        """Ensure the table exists and is mounted.

        Creates the table with :meth:`create` defaults if it does not exist,
        then mounts it.

        :param sleep_time: optional extra pause (seconds) after mounting.
        :param kwargs: accepted and ignored.
        :return: ``self``, to allow chaining.
        """
        if not self.yt.exists(self.path):
            self.create()
        self.mount_table()
        if sleep_time:
            time.sleep(sleep_time)
        return self

    @property
    def batched_inserter(self):
        """A new :class:`BatchedInserter` with default settings."""
        return self.get_batched_inserter()

    def get_batched_inserter(self, *args, **kwargs):
        """Create a :class:`BatchedInserter` for this storage; args are forwarded."""
        return BatchedInserter(self, *args, **kwargs)

    @staticmethod
    def record(record, *args, **kwargs):
        """Turn the arguments of ``BatchedInserter.insert_row`` into a row.

        The base implementation returns *record* unchanged.  Presumably it is
        meant to be overridden by subclasses -- TODO confirm.
        """
        return record


class BatchedInserter(object):
    """Buffer rows and write them to a storage in batches.

    Intended as a context manager.  Entering calls
    ``storage.prepare_table(sleep_time=0.2)``.  Leaving without an exception
    writes the remaining buffered rows.  If the ``with`` body raises, the
    buffered rows are not written and the exception propagates.
    """

    def __init__(self, storage, max_batch_size=100000):
        """
        :param storage: object providing ``prepare_table``, ``insert_rows``
            and ``record`` (e.g. an ``AbstractDynamicStorage``).
        :param max_batch_size: number of rows sent per automatic
            ``insert_rows`` call; values below 1 are treated as 1.
        """
        self.storage = storage
        self.max_batch_size = max_batch_size
        self.rows = []

    def __enter__(self):
        self.storage.prepare_table(sleep_time=0.2)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Flush only on success; returning None lets any exception propagate.
        if exc_type is None:
            self.push_rows()

    def push_rows(self):
        """Write all buffered rows in one ``insert_rows`` call and clear the buffer.

        The buffer is cleared only after the write succeeds, so rows survive a
        failed write.
        """
        if self.rows:
            self.storage.insert_rows(self.rows)
            self.rows = []

    def insert_rows(self, rows):
        """Buffer *rows*, writing every full batch of ``max_batch_size`` rows.

        Each batch sent is exactly ``max_batch_size`` rows.  The previous
        ``len > max_batch_size`` check sent batches one row over the limit and
        pushed large inputs as a single oversized batch.  Leftover rows stay
        buffered until the next call or :meth:`push_rows`.
        """
        self.rows.extend(rows)
        # Guard against a non-positive size, which would otherwise loop forever.
        batch_size = max(self.max_batch_size, 1)
        written = 0
        try:
            while len(self.rows) - written >= batch_size:
                self.storage.insert_rows(self.rows[written:written + batch_size])
                written += batch_size
        finally:
            # Drop only what was actually written, even if a later batch failed.
            del self.rows[:written]

    def insert_row(self, *args, **kwargs):
        """Build one row via ``storage.record(*args, **kwargs)`` and buffer it."""
        record = self.storage.record(*args, **kwargs)
        self.insert_rows([record])
