import collections
import collections.abc
import contextlib
import datetime
import functools
import heapq
import logging
import operator
import threading
import time
import timeit

import psycopg2
import retry

import django.db
from django.db import transaction

from cars.core.solomon import SolomonHelper as SolomonHelperBase
from cars.core.util import datetime_helper, make_yt_client
from cars.core.util import collection_to_mapping  # moved to core.util


# Module-level logger, named after this module via __name__.
LOGGER = logging.getLogger(__name__)


def merge_two_lists(left, right, *, key=None, reversed_source=False, reverse_result=False):
    """Merge two sorted sequences into a new sorted list.

    The inputs are not modified (earlier versions emptied them via ``pop``).

    :param left: sequence sorted ascending by ``key`` (descending if
        ``reversed_source`` is true).
    :param right: second sequence, sorted the same way as ``left``.
    :param key: optional function computing the sort key; identity by default.
    :param reversed_source: the inputs are sorted in descending order.
    :param reverse_result: return the merge in the order opposite to the
        inputs' order.
    :return: a new list. Elements with equal keys keep ``left``'s elements
        ahead of ``right``'s (relative to the source order).
    """
    key = key or (lambda x: x)
    # Walk both sequences from their tails. ``compare`` selects the element that
    # belongs furthest from the front in source order.
    compare = operator.lt if reversed_source else operator.gt

    result = []
    i, j = len(left) - 1, len(right) - 1
    while i >= 0 and j >= 0:
        # On ties take ``right``'s element: ``result`` is built back-to-front,
        # so this keeps ``left``'s equal elements first in the final order.
        if compare(key(left[i]), key(right[j])):
            result.append(left[i])
            i -= 1
        else:
            result.append(right[j])
            j -= 1

    # At most one of these prefixes is non-empty; add it tail-first as well.
    result.extend(reversed(left[:i + 1]))
    result.extend(reversed(right[:j + 1]))

    # ``result`` is currently reversed relative to the source order.
    if not reverse_result:
        result.reverse()
    return result


def merge_lists(*lists, key=None, reversed_source=False, reverse_result=False):
    """Merge any number of sorted sequences into a single sorted list.

    :param lists: sequences (or other iterables), each sorted ascending by
        ``key`` (descending if ``reversed_source`` is true).
    :param key: optional function computing the sort key; identity by default.
    :param reversed_source: the inputs are sorted in descending order.
    :param reverse_result: return the merge in the order opposite to the
        inputs' order.
    :return: a new list; the inputs are not modified. The merge is stable:
        elements with equal keys appear in the order of the input lists.
    """
    # heapq.merge is a stable k-way merge. Unlike the previous pairwise-merge
    # heap, it never compares list objects and never sees partially reversed
    # intermediate results.
    merged = list(heapq.merge(*lists, key=key, reverse=reversed_source))
    # Reverse only the final result, including the single-list case.
    if reverse_result:
        merged.reverse()
    return merged


def atomic_with_retries(tries=3, delay=0.001, backoff=2):
    """Decorator running a function inside ``transaction.atomic`` and retrying on DB errors.

    Usable bare (``@atomic_with_retries``) or with arguments
    (``@atomic_with_retries(tries=5)``).

    :param tries: forwarded unchanged to ``retry.retry``.
    :param delay: forwarded unchanged to ``retry.retry``.
    :param backoff: forwarded unchanged to ``retry.retry``.
    """
    # Bare-decorator form: the decorated function arrives in ``tries``.
    # The builtin ``callable`` replaces ``collections.Callable``, which was
    # removed in Python 3.10.
    if callable(tries):
        return atomic_with_retries()(tries)

    def wrapped(f):
        db_exceptions = (
            django.db.DatabaseError,
            django.db.OperationalError,
            psycopg2.DatabaseError,
            psycopg2.OperationalError,
        )

        # try to re-apply transaction if lock fails / PgBouncer connection errors
        @retry.retry(exceptions=db_exceptions, tries=tries, delay=delay, backoff=backoff)
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                with transaction.atomic(savepoint=False):
                    return f(*args, **kwargs)
            except db_exceptions:
                # Let Django drop broken connections before the retry re-runs
                # the function, then re-raise so ``retry`` sees the failure.
                django.db.close_old_connections()
                raise

        return wrapper

    return wrapped


class RateLimiter(object):
    """Allow at most ``limit`` calls per one-second window, shared across threads.

    ``sync()`` reserves a slot and returns how long the caller must sleep
    before using it. ``wait()`` reserves a slot and performs that sleep.
    """

    def __init__(self, limit):
        # Reserved (possibly future) call times in time.monotonic() seconds,
        # oldest first; at most ``limit`` entries are kept.
        self._times = collections.deque(maxlen=limit)
        self._lock = threading.RLock()

    def sync(self):
        """Reserve a call slot and return the seconds to sleep before using it (0 if none)."""
        time_to_sleep = 0

        with self._lock:
            # Monotonic clock: wall-clock adjustments must not distort waits.
            now = time.monotonic()

            # Drop reservations older than one second. The age is re-evaluated
            # for every head element; the previous version computed it once and
            # then discarded recent reservations too, allowing bursts past the limit.
            while self._times and now - self._times[0] > 1:
                self._times.popleft()

            if len(self._times) == self._times.maxlen:  # window still full
                # Run one second after the oldest reservation (which may itself
                # lie in the future), and give up that reservation's slot.
                time_to_sleep = 1 - (now - self._times[0])
                self._times.popleft()

            self._times.append(now + time_to_sleep)

        return time_to_sleep

    def wait(self):
        """Reserve a call slot, sleeping first if the rate limit requires it."""
        time_to_sleep = self.sync()
        if time_to_sleep > 0:
            time.sleep(time_to_sleep)


@contextlib.contextmanager
def timer():
    """Context manager that prints how long its ``with`` body took.

    The elapsed time is measured with ``timeit.default_timer`` and printed
    even when the body raises; the exception still propagates.
    """
    started_at = timeit.default_timer()
    try:
        yield
    finally:
        elapsed = timeit.default_timer() - started_at
        print('total time: {}s'.format(elapsed))


class SolomonHelper(SolomonHelperBase):
    """Project Solomon helper bound to the 'request_aggregator' first argument.

    Callers supply only ``sensor_prefix``; the base helper receives
    ``'request_aggregator'`` followed by that prefix.
    """

    def __init__(self, sensor_prefix):
        super(SolomonHelper, self).__init__('request_aggregator', sensor_prefix)


class SolomonReporter(SolomonHelper):
    """Decorator that reports the wrapped function's return value to Solomon.

    Depending on the return value:

    * mapping: passed as keyword arguments to ``increment_counters``;
    * int: passed to ``increment_counter`` under the default sensor name;
    * anything else: nothing is reported.

    The return value is always passed through unchanged.
    """

    DEFAULT_SENSOR_NAME = 'processed_entries_count'

    def __init__(self, sensor_prefix, default_sensor_name=None):
        super().__init__(sensor_prefix)
        self._default_sensor_name = default_sensor_name or self.DEFAULT_SENSOR_NAME

    def __call__(self, f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            result = f(*args, **kwargs)
            # collections.abc.Mapping: the old ``collections.Mapping`` alias
            # was removed in Python 3.10.
            if isinstance(result, collections.abc.Mapping):
                self.increment_counters(**result)
            elif isinstance(result, int):
                # NOTE: bool is an int subclass, so True/False are reported as 1/0.
                self.increment_counter(self._default_sensor_name, result)
            return result

        return wrapper


class YTHelper(object):
    """Wrapper around a YT client (``make_yt_client('data')``) bound to one table path."""

    def __init__(self, table_path):
        self._yt = make_yt_client('data')
        self._table_path = table_path

    @property
    def table_path(self):
        return self._table_path

    def temp_table_path(self):
        """Return ``'<table_path>_temp_<current time in microseconds>'``."""
        timestamp = str(int(time.time() * 1000000))
        return '{}_temp_{}'.format(self.table_path, timestamp)

    def _get_property(self, *path_parts):
        """Read a nested value from the client config, e.g. ``('read_parallel', 'enable')``."""
        assert len(path_parts) >= 1
        base = self._yt.config
        for part in path_parts:
            base = base[part]
        return base

    def _set_property(self, new_value, *path_parts):
        """Set a nested client-config value and return the value it replaced."""
        assert len(path_parts) >= 1

        base = self._yt.config
        for part in path_parts[:-1]:
            base = base[part]

        old_value = base[path_parts[-1]]
        base[path_parts[-1]] = new_value

        return old_value

    def get_table_row_count(self):
        """Return the table's row count, or None if the table does not exist."""
        if self._yt.exists(self._table_path):
            row_count = self._yt.row_count(self._table_path)
        else:
            row_count = None
        return row_count

    def check_table_data_exists(self, specific_date, date_column_name='date'):
        """Return True if the table's last row has a date >= ``specific_date``.

        Only the last row (by index) is read. Its ``date_column_name`` value is
        parsed with the ``'%Y-%m-%d'`` format. Returns False for a missing or
        empty table.
        NOTE(review): assumes rows are appended in date order — confirm.
        """
        row_count = self.get_table_row_count()

        if row_count:
            table_path_wrap = self._yt.TablePath(
                self._table_path,
                columns=[date_column_name],
                exact_index=row_count - 1
            )

            last_entry = next(self._yt.read_table(table_path_wrap, format='json'))
            last_date_str = last_entry[date_column_name]
            last_date = datetime_helper.localize(
                datetime.datetime.strptime(last_date_str, "%Y-%m-%d")
            ).date()

            return last_date >= specific_date

        return False

    def import_data(self, start_index, end_index, is_parallel=True):
        """Yield rows in ``[start_index, end_index)`` (bounds forwarded to ``TablePath``) as JSON-format rows.

        While the generator runs, the client config ``read_parallel.enable``
        is set to ``is_parallel``. The previous value is restored when the
        generator finishes, fails or is closed.
        """
        # Build the path before touching the config: if this raises, the
        # config has not been modified yet and nothing needs restoring.
        table_path_wrap = self._yt.TablePath(
            self._table_path,
            start_index=start_index,
            end_index=end_index,
        )

        curr_value = self._set_property(is_parallel, "read_parallel", "enable")
        try:
            for row in self._yt.read_table(table_path_wrap, format='json'):
                yield row
        finally:
            self._set_property(curr_value, "read_parallel", "enable")

    def export_data(self, data_to_store, *, append=True, raw=None, schema=None):
        """Create the table if missing and write ``data_to_store`` inside a YT transaction.

        ``schema`` (if given) is applied only when the table is created;
        ``append`` is forwarded to ``TablePath`` and ``raw`` to ``write_table``.
        """
        # NOTE(review): timeout presumably in milliseconds (50 minutes) — confirm.
        with self._yt.Transaction(timeout=50 * 60 * 1000, ping_ancestor_transactions=True):
            kwargs = {'attributes': {'schema': schema}} if schema is not None else {}
            self._yt.create("table", self._table_path, recursive=True, ignore_existing=True, **kwargs)

            table_path_wrap = self._yt.TablePath(self._table_path, append=append)
            self._yt.write_table(table_path_wrap, data_to_store, raw=raw)

    def merge_chunks(self):
        """Run an in-place merge of the table with ``combine_chunks`` enabled."""
        self._yt.run_merge(
            self._table_path,
            self._table_path,
            spec={'combine_chunks': True}
        )
