# -*- coding: utf-8 -*-
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=too-few-public-methods
from __future__ import unicode_literals
from itertools import islice
import contextlib
import logging
import json
import re
import os
import requests
import six
import functools
from chapi.errors import ProgrammingError

# Module logger named 'clickhouse_client'; dump_request_error() reports
# failed HTTP requests through it.
logger = logging.getLogger('clickhouse_client')


@contextlib.contextmanager
def dump_request_error():
    """Log details of a failed ``requests`` call, then re-raise it.

    Usage::

        with dump_request_error():
            response = session.post(url, data=query)
            response.raise_for_status()

    When a :class:`requests.exceptions.RequestException` escapes the
    ``with`` body, the following are written with ``logger.error``:

    * the request method, URL, body and headers, when the exception
      carries a request;
    * the response status, reason and content, when it carries a response.

    The exception is then re-raised unchanged. ``password=<value>``
    fragments are masked as ``password=***`` across the whole message,
    including the request URL and body, not only the response.
    """
    try:
        yield
    except requests.exceptions.RequestException as e:
        if e.request is not None:
            pattern = '{r.method} request {r.url} failed with error:\n{e}\n'\
                      'request body: {r.body}\n'\
                      'request headers: {headers}'
            headers = dict(e.request.headers)
            status_str = pattern.format(e=e, r=e.request, headers=headers)
        else:
            # Some errors (e.g. a malformed URL) are raised before any
            # request object exists; don't let an AttributeError here
            # hide the real exception.
            status_str = u'request failed with error:\n{}'.format(e)
        if e.response is not None:
            pattern = u'response status {r.status_code}({r.reason})\n'\
                      u'{r.content}'
            response_status = pattern.format(r=e.response)
            status_str = u'{}\n{}'.format(status_str, response_status)
        # Mask credentials in the full message: the URL and body are where
        # they are usually sent. Stop at whitespace as well as '&' so the
        # mask does not swallow the rest of the multi-line message.
        status_str = re.sub(r'password=[^&\s]+', 'password=***', status_str)
        logger.error(status_str)
        raise


def _defaults(json_path):
    """
    .defaults example
    just valid json file

    cat ~/.chapi/.defaults
    {"username": "__your__username__",
     "password": "__your__password__",
     "host": "mtstat.yandex.ru",
     "port": 8123,
     "connect_timeout": 60,
     "read_timeout": 60
    }
    """
    path = os.path.expanduser('~/.chapi/.defaults')
    if json_path:
        path = json_path
    try:
        with open(path) as f:
            defaults = json.load(f)
    except IOError:
        if isinstance(json_path, six.string_types):
            try:
                defaults = json.loads(json_path)
            except ValueError:
                defaults = dict()
        else:
            defaults = dict()
    return defaults


def take(n, iterable):
    """Return the first ``n`` items of ``iterable`` as a list.

    A shorter list is returned when the iterable runs out first. Only the
    items taken are consumed from an iterator.
    """
    head = islice(iterable, n)
    return list(head)


def get_iter_length(iterator):
    """Return ``len(iterator)``, or None when ``len()`` raises TypeError."""
    length = None
    try:
        length = len(iterator)
    except TypeError:
        pass
    return length


def integers():
    """Yield 0, 1, 2, ... without end."""
    current = 0
    while True:
        yield current
        current = current + 1


def safefetch(fetch):
    """Decorate ``fetch`` so running out of data yields None.

    StopIteration and IndexError raised by the wrapped callable are
    swallowed and None is returned instead. Any other exception
    propagates. ``functools.wraps`` keeps the wrapped function's metadata.
    """
    swallowed = (StopIteration, IndexError)

    @functools.wraps(fetch)
    def guarded(*args, **kwargs):
        try:
            result = fetch(*args, **kwargs)
        except swallowed:
            result = None
        return result

    return guarded


def plain(dic):
    """Return the values of ``dic`` as a list, in the mapping's order."""
    values = dic.values()
    return list(values)


class ProgIter(object):
    """Iterable wrapper that optionally shows a tqdm progress bar.

    :param iterable: the underlying iterable.
    :param total: expected number of items, passed to tqdm as ``total``.
    :param trace: when falsy, iteration passes straight through.
    :param jupyter: when truthy, use tqdm's notebook widget instead of the
        console bar.

    If tqdm, or the notebook support it needs, raises ImportError,
    iteration falls back to the plain iterable.
    """

    def __init__(self, iterable, total, trace, jupyter):
        self.trace = trace
        self.iterable = iterable
        self.total = total
        self.jupyter = jupyter

    def _bar_factory(self):
        """Return the tqdm progress-bar callable to use (may raise ImportError)."""
        import tqdm
        if not self.jupyter:
            return tqdm.tqdm
        try:
            # ``tqdm.tqdm_notebook`` is deprecated; tqdm >= 4.36 provides
            # ``tqdm.notebook.tqdm`` as its replacement.
            from tqdm.notebook import tqdm as notebook_bar
        except ImportError:
            # Older tqdm releases only have the legacy name.
            notebook_bar = tqdm.tqdm_notebook
        return notebook_bar

    def __iter__(self):
        if not self.trace:
            return iter(self.iterable)
        try:
            bar = self._bar_factory()
            return iter(bar(iter(self.iterable),
                            total=self.total, leave=True))
        except ImportError:
            # No tqdm, or no widget support for the notebook bar.
            return iter(self.iterable)


# noinspection SqlNoDataSourceInspection
class TmpTable(object):
    """Context manager that fills a scratch table from a SELECT.

    On enter, any existing table of the same name is dropped and the table
    is created from ``select``; its name is returned. On exit the table is
    dropped again, whether or not the ``with`` body raised.

    Usage:
    ------
    >>  with TmpTable(conn, HEAVY_QUERY, 'mydb') as tablename:
    ...     # do your stuff
    """
    tmp_create = """
    CREATE TABLE {table} ENGINE = Log AS
    {select}
    """
    db_create = """
    CREATE DATABASE IF NOT EXISTS {db}
    """
    tmp_drop = """
    DROP TABLE IF EXISTS {table}
    """

    def __init__(self, conn, select, db=None, name=None):
        """Build the SQL statements; nothing runs until ``__enter__``.

        :param conn: object exposing ``request(query, _output=False)``.
        :param select: SELECT statement whose result fills the table.
        :param db: database that holds the table, created if missing. When
            falsy, the table name is unqualified and includes ``id(self)``.
        :param name: table-name suffix. Defaults to the SHA-224 hex digest
            of ``select``, which is deterministic, not random.
        """
        if name is None:
            import hashlib
            name = hashlib.sha224(select.encode('utf-8')).hexdigest()
        self.conn = conn
        if db:
            self.table = '{}.tmp_{}'.format(db, name)
            self.prepare = self.db_create.format(db=db)
        else:
            self.table = 'tmp_auto_{:d}_{}'.format(id(self), name)
            self.prepare = None
        self.destruction = self.tmp_drop.format(table=self.table)
        self.creation = self.tmp_create.format(table=self.table, select=select)

    def __enter__(self):
        # Order matters: make sure the db exists, drop any leftover table,
        # then create it afresh.
        statements = [self.destruction, self.creation]
        if self.prepare:
            statements.insert(0, self.prepare)
        for statement in statements:
            self.conn.request(statement, _output=False)
        return self.table

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.conn.request(self.destruction, _output=False)


def date_range(start, end):
    """
    Return every date from ``start`` to ``end``, both ends inclusive.

    Strings must be in ``%Y-%m-%d`` format; ``datetime.date`` values are
    used unchanged.

    >>> date_range('2016-07-07', '2016-07-07')
    [datetime.date(2016, 7, 7)]
    >>> date_range('2016-07-07', '2016-07-08')
    [datetime.date(2016, 7, 7), datetime.date(2016, 7, 8)]
    >>> date_range('2016-07-07', '2016-07-06')
    Traceback (most recent call last):
    ...
    TypeError: start > end


    :type start: str | datetime.date
    :type end: str | datetime.date
    :return: list of dates in the period
    :raises TypeError: if an argument is neither a string nor a date, or if
        ``start`` is after ``end`` (TypeError is kept for compatibility).
    :raises ValueError: if a string does not match ``%Y-%m-%d``.
    """
    import datetime

    def _to_date(value):
        # Dates pass through and strings are parsed. Anything else is
        # returned as-is and rejected by the type check below.
        if isinstance(value, datetime.date):
            return value
        if isinstance(value, six.string_types):
            return datetime.datetime.strptime(value, '%Y-%m-%d').date()
        return value

    start = _to_date(start)
    end = _to_date(end)
    if not (isinstance(start, datetime.date) and
            isinstance(end, datetime.date)):
        raise TypeError('start and end must be string like '
                        '%Y-%m-%d or instances of datetime.date')
    if start > end:
        raise TypeError('start > end')
    # +1 makes the range inclusive of ``end``; start == end yields [start].
    return [start + datetime.timedelta(days=offset)
            for offset in range((end - start).days + 1)]


class Delayer(object):
    """Re-iterable, finite sequence of wait times.

    Each ``iter(delayer)`` starts a fresh generator of ``n`` values for
    ``i`` in ``0 .. n-1``. The chosen strategy with parameter ``c`` gives
    (examples use the default ``c=2``):

    * ``'lin'``   -- ``i * c``   e.g. 0, 2, 4, 6
    * ``'exp'``   -- ``c ** i``  e.g. 1, 2, 4, 8
    * ``'poly'``  -- ``i ** c``  e.g. 0, 1, 4, 9
    * ``'const'`` -- ``c``       e.g. 2, 2, 2, 2

    An unknown strategy raises ``ProgrammingError``.
    """
    strategies = ['lin', 'exp', 'poly', 'const']

    @staticmethod
    def const(n, c):
        for _ in range(n):
            yield c

    @staticmethod
    def lin(n, c):
        for step in range(n):
            yield step * c

    @staticmethod
    def exp(n, c):
        for power in range(n):
            yield c ** power

    @staticmethod
    def poly(n, c):
        for base in range(n):
            yield base ** c

    def __init__(self, n, c=2, strategy='lin'):
        if strategy not in self.strategies:
            raise ProgrammingError('strategy %s not in %s' %
                                   (strategy, self.strategies))

        def wait_gen():
            # The strategy is looked up on every iteration, not once here.
            return getattr(self, strategy)(n, c)

        self.wait_gen = wait_gen

    def __iter__(self):
        return self.wait_gen()
