# -*- coding: utf-8 -*-

import certifi
import requests
import memcache
import settings
import logging
from base64 import b64decode, b64encode
import re
import json
from . import escape_string

# Make the project's custom CA trusted by appending it to the certifi bundle.
# This runs at import time, so the append is guarded: without the check the
# same certificate would be written to the bundle again on every import.
cafile = certifi.where()
with open(settings.CERT_FILE, 'rb') as infile:
    customca = infile.read()
with open(cafile, 'rb') as bundle:
    already_trusted = customca in bundle.read()
if not already_trusted:
    with open(cafile, 'ab') as outfile:
        outfile.write(customca)

class ClickhouseException(Exception):
    """Error raised by ClickhouseClient.

    Used when a statement is refused because the client is readonly and when
    the server answers with a status code other than 200.
    """


class ClickhouseClient(object):
    """Minimal ClickHouse client that posts queries over HTTP(S) via ``requests``.

    Queries are plain strings that may contain ``%(name)s`` placeholders.
    Placeholders are filled from ``query_params`` after the values pass
    ``query_params_check``; most values also go through ``escape_string``.
    """

    # Placeholders substituted verbatim, i.e. without escape_string().
    _UNESCAPED_PARAMS = (
        'metric_ids',
        'job_ids',
        'cases_with_tag',
        'tags',
        'targets',
    )

    # Settings appended to the URL by select()/select_tsv() on non-readonly clients.
    _SELECT_URL_SETTINGS = '?joined_subquery_requires_alias=0&max_threads=3'

    def __init__(self, url='https://{}:{}'.format(settings.CLICKHOUSE_HOST, settings.CLICKHOUSE_PORT),
                 connection_timeout=1500,
                 readonly=False):
        """Create the HTTP session used for all requests.

        :param url: base URL of the ClickHouse HTTP interface; the default is
            built from ``settings`` once, when the class is defined
        :param connection_timeout: stored on the instance only; note that
            ``_post`` currently uses a fixed timeout of 10 instead of this value
        :param readonly: when True, ``insert`` is refused and ``execute`` only
            accepts statements starting with ``select``
        """
        self.connection_timeout = connection_timeout
        self.url = url
        self.readonly = readonly
        self.session = requests.Session()
        self.session.verify = True
        self.session.headers.update(
            {'X-ClickHouse-User': settings.CLICKHOUSE_USER, 'X-ClickHouse-Key': settings.CLICKHOUSE_PWD}
        )
        if settings.env_type == 'production':
            # Host header = the host part of the URL ("scheme://host:port" -> "host").
            self.session.headers.update({'Host': self.url.split(':')[1].lstrip('//')})

    def execute(self, query, query_params=None):
        """Run an arbitrary statement and discard the response body.

        :param query: statement text, optionally with ``%(name)s`` placeholders
        :param query_params: dict of placeholder values, or None
        :raises ClickhouseException: if the client is readonly and the
            statement is not a select, or the server does not answer 200
        :raises SanitaryError: if ``query_params`` fails the sanity check
        """
        if self.readonly and not self._starts_with(query, 'select'):
            raise ClickhouseException('Clickhouse client is in readonly state')
        query = self._prepare_query(query, query_params=query_params)
        self._check_response(self._post(query), query)

    def insert(self, query, query_params=None):
        """Run an ``insert`` statement.

        :param query: statement text, optionally with ``%(name)s`` placeholders
        :param query_params: dict of placeholder values, or None
        :raises ClickhouseException: if the client is readonly or the server
            does not answer 200 (the response body is truncated to 200 chars)
        :raises AssertionError: if the statement does not start with ``insert``
        :raises SanitaryError: if ``query_params`` fails the sanity check
        """
        if self.readonly:
            raise ClickhouseException('Clickhouse client is in readonly state')
        query = self._prepare_query(query, query_params=query_params)
        self._require_statement(query, 'insert')
        r = self._post(query)
        if r.status_code != 200:
            raise ClickhouseException('%s\n%s' % (r.status_code, r.content.rstrip('\n')[:200]))

    def select(self, query, query_params=None, deserialize=True):
        """Run a ``select`` statement in ``JSONCompact`` format.

        :param query: statement text, optionally with ``%(name)s`` placeholders
        :param query_params: dict of placeholder values, or None
        :param deserialize: when True return the ``data`` member of the parsed
            JSON response, otherwise return the raw response body
        :raises ClickhouseException: if the server does not answer 200
        :raises AssertionError: if the statement does not start with ``select``
        :raises SanitaryError: if ``query_params`` fails the sanity check
        """
        query = self._prepare_query(query, query_params=query_params)
        query = query.rstrip().rstrip(';') + ' FORMAT JSONCompact'
        self._apply_select_settings()
        self._require_statement(query, 'select')
        r = self._post(query)
        self._check_response(r, query)
        if deserialize:
            return json.loads(r.content)['data']
        return r.content

    def select_tsv(self, query, query_params=None):
        """Run a ``select`` statement and return the raw response body.

        No ``FORMAT`` clause is appended, so the server's output format for
        the statement as written is returned unchanged.

        :param query: statement text, optionally with ``%(name)s`` placeholders
        :param query_params: dict of placeholder values, or None
        :raises ClickhouseException: if the server does not answer 200
        :raises AssertionError: if the statement does not start with ``select``
        :raises SanitaryError: if ``query_params`` fails the sanity check
        """
        query = self._prepare_query(query, query_params=query_params)
        query = query.rstrip().rstrip(';')
        self._apply_select_settings()
        self._require_statement(query, 'select')
        r = self._post(query)
        self._check_response(r, query)
        return r.content

    @staticmethod
    def _starts_with(query, keyword):
        """Tell whether the statement begins with ``keyword``, ignoring case and leading whitespace."""
        return query.lstrip()[:10].lower().startswith(keyword)

    def _require_statement(self, query, keyword):
        """Raise AssertionError unless the statement begins with ``keyword``.

        Raised explicitly instead of using ``assert`` so that the check is
        still enforced when Python runs with ``-O``.
        """
        if not self._starts_with(query, keyword):
            raise AssertionError('query must start with "%s"' % keyword)

    def _apply_select_settings(self):
        """Append the select-only settings to ``self.url`` once (non-readonly clients only)."""
        if not self.readonly and not self.url.endswith('max_threads=3'):
            self.url += self._SELECT_URL_SETTINGS

    @staticmethod
    def _check_response(r, query):
        """Raise ClickhouseException carrying status, body and query unless the status is 200."""
        if r.status_code != 200:
            raise ClickhouseException('%s\n%s\n=====\n%s\n====' % (r.status_code, r.content.rstrip('\n'), query))

    def _prepare_query(self, query, query_params=None):
        """Validate ``query_params``, substitute them into ``query`` and encode to UTF-8.

        Values whose key is listed in ``_UNESCAPED_PARAMS`` are inserted as
        ``unicode(value)``; every other value is additionally passed through
        ``escape_string``.

        :raises SanitaryError: if ``query_params`` fails the sanity check
        """
        self.query_params_check(query_params)
        if query_params is not None:
            substitutions = {}
            for key, item in query_params.iteritems():
                text = unicode(item)
                if key not in self._UNESCAPED_PARAMS:
                    text = escape_string(text)
                substitutions[key] = text
            query %= substitutions
        return query.encode('utf-8')

    def _post(self, query):
        """POST the query body to ``self.url`` and return the ``requests`` response."""
        return self.session.post(
                self.url,
                data=query,
                timeout=10,
                verify=settings.CERT_FILE
        )

    @staticmethod
    def query_params_check(query_params):
        """Sanity-check placeholder values before they are put into a query.

        ``job``, ``job_date`` and ``compress_ratio`` must consist of digits
        only; any other value must not contain the substrings ``select`` or
        ``remote`` (case-sensitive).

        The checks are explicit conditions rather than ``assert`` statements
        so they are not stripped when Python runs with ``-O``.

        :param query_params: must be dict or None
        :return: None
        :raises SanitaryError: if ``query_params`` is neither None nor a dict,
            or if any value fails its check
        """
        rules = {
            'job': lambda p: str(p).isdigit(),
            'job_date': lambda p: str(p).isdigit(),
            'compress_ratio': lambda p: str(p).isdigit(),
        }
        if query_params is None:
            return
        if not isinstance(query_params, dict):
            raise SanitaryError('query_params must be of dict type')
        for key, param in query_params.items():
            rule = rules.get(key)
            if rule is not None:
                valid = rule(param)
            else:
                text = str(param)
                valid = 'select' not in text and 'remote' not in text
            if not valid:
                raise SanitaryError('param "%s" is invalid: %s' % (key, param))


class SanitaryError(Exception):
    """Raised when query parameters fail the sanity check or are not a dict."""


class Singleton(type):
    """Metaclass that creates at most one instance per class.

    The first call of a class builds the instance with the given arguments
    and stores it in ``instances``; later calls return the stored instance
    and ignore their arguments.
    """

    # Shared registry: class -> its single instance.
    instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the stored instance of ``cls``, creating it on first use."""
        registry = cls.instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return registry[cls]


class MemCache(memcache.Client):
    """``memcache.Client`` that serializes values and stringifies keys.

    Values are encoded on ``set`` and decoded on ``get`` according to ``fmt``:
    ``'json'`` uses ``json.dumps``/``json.loads``, ``'b64'`` uses
    ``b64encode``/``b64decode``, any other value stores the data unchanged.
    """

    def __init__(self, servers, fmt='json', **kwargs):
        """
        :param servers: passed to ``memcache.Client``
        :param fmt: serialization format, ``'json'`` or ``'b64'``
        :param kwargs: passed to ``memcache.Client``
        """
        super(MemCache, self).__init__(servers, **kwargs)
        self.fmt = fmt

    def set(self, key, val, **kwargs):
        """Serialize ``val`` and store it under ``str(key)``.

        :param kwargs: forwarded to ``memcache.Client.set`` (previously they
            were accepted but dropped, so e.g. an expiry was never applied)
        :return: the result of ``memcache.Client.set``, or None if a
            TypeError was raised (the error is logged, not propagated)
        """
        try:
            if self.fmt == 'json':
                val = json.dumps(val)
            elif self.fmt == 'b64':
                val = b64encode(val)
            return super(MemCache, self).set(str(key), val, **kwargs)
        except TypeError as exc:
            logging.error(exc)

    def get(self, key):
        """Fetch and deserialize the value stored under ``str(key)``.

        :return: the decoded value, or None when the key is missing or the
            stored data cannot be decoded (the error is logged)
        """
        val = super(MemCache, self).get(str(key))
        if val is None:
            return
        try:
            if self.fmt == 'json':
                val = json.loads(val)
            elif self.fmt == 'b64':
                val = b64decode(val)
            return val
        except (TypeError, ValueError) as exc:
            # json.loads reports malformed data as ValueError; treat it like
            # the other decode failures instead of letting it escape.
            logging.error(exc)

    def delete(self, key, **kwargs):
        """Delete ``str(key)``.

        :param kwargs: forwarded to ``memcache.Client.delete``
        :return: the result of ``memcache.Client.delete``
        """
        return super(MemCache, self).delete(str(key), **kwargs)


class CacheMock(object):
    """Cache stand-in whose operations accept anything and do nothing."""

    def set(self, *args, **kwargs):
        """Ignore the arguments; nothing is stored."""
        return

    def get(self, *args, **kwargs):
        """Ignore the arguments; always a miss (None)."""
        return

    def delete(self, *args, **kwargs):
        """Ignore the arguments; nothing is removed."""
        return


class CacheClient(object):
    def __new__(cls, expire=2592000, fmt='json'):  # a month
        """
        Build the cache client used by the application.

        :param expire: not used by the current implementation
        :param fmt: not used by the current implementation
        :return: a CacheMock instance (not an instance of CacheClient)
        """
        # The memcache-backed client is disabled; a no-op mock is returned instead:
        # instance = MemCache(['{}:{}'.format(settings.MEMCACHE_HOST, settings.MEMCACHE_PORT)])
        return CacheMock()
