# -*- coding: utf-8 -*-

import requests
import settings
import logging
import re
import json
from urllib.parse import urljoin
from xml.etree import ElementTree
from base64 import b64decode, b64encode
from collections import namedtuple
import redis
from redis.sentinel import Sentinel

from startrek_client import Startrek
from startrek_client.exceptions import NotFound, Forbidden

from . import memcache_v59 as memcache
from .meta import escape_string
from .exceptions import TaskError, SanitaryError, ClickhouseException
from settings import REDIS_HOST, REDIS_PORT, REDIS_SENTINEL_PORT, REDIS_DB_NAME, REDIS_DB_PWD, ENV_TYPE


class StaffClient(object):
    """
    Thin wrapper over the Staff HTTP API.

    Doesn't catch any exceptions: network errors (and, where ``raise_for_status``
    is called, HTTP errors) propagate to the caller.
    """

    def __init__(self, request=None):
        # Stored as-is; no method of this class reads it.
        self.request = request
        self.urlbase = "https://staff-api.yandex-team.ru/v3/"
        self.headers = {
            'Host': 'staff-api.yandex-team.ru',
            'Authorization': 'OAuth {}'.format(settings.OAUTH_TOKEN),
        }
        self.verify = settings.CERT_FILE
        self.timeout = 2

    def get_user_info(self, user_login, fields=None):
        """
        Look up a person by login.

        :param user_login: staff login, inserted into the query string as-is
        :param fields: optional iterable of field names, sent comma-separated as ``_fields``
        :return: decoded JSON body (the HTTP status is not checked)
        """
        query = "persons?login={}".format(user_login)
        if fields:
            query = "{}&_fields={}".format(query, ",".join(fields))
        response = requests.get(self.urlbase + query, verify=self.verify,
                                timeout=self.timeout, headers=self.headers)
        return response.json()

    def get_department_logins(self, dpt, dismissed=False):
        """
        List logins of the members of a department group.

        :param dpt: department group id; must be an int
        :param dismissed: when truthy, ``official.is_dismissed=true`` is added to the
            query; when falsy no dismissal filter is sent at all
        :return: ``login`` of every entry in the response's ``result`` list
            (NOTE(review): only this single response is read — if the API paginates,
            later pages are not fetched; confirm)
        :raises AssertionError: if ``dpt`` is not an int
        :raises requests.HTTPError: on a non-2xx response
        """
        dismissed_filter = '&official.is_dismissed=true' if dismissed else ''
        assert isinstance(dpt, int)
        url = '{}persons?department_group.id={}{}&_fields=login'.format(
            self.urlbase, dpt, dismissed_filter)
        response = requests.get(url, verify=self.verify, timeout=self.timeout, headers=self.headers)
        response.raise_for_status()
        return [person['login'] for person in response.json()['result']]

    def get_rfid_user_info(self, rfid):
        """
        Resolve an RFID badge through the Staff rfid-api export.

        :param rfid: badge code; the first 4 characters of ``str(rfid)`` are dropped
            and the remainder is parsed as a hexadecimal number
        :return: decoded JSON body
        :raises requests.HTTPError: on a non-2xx response
        """
        # TODO: memcache
        badge_number = int(str(rfid)[4:], 16)
        # Same auth headers, but this endpoint lives on a different host.
        headers = dict(self.headers, Host='staff.yandex-team.ru')
        response = requests.get(
            'https://staff.yandex-team.ru/rfid-api/export/{}/'.format(badge_number),
            verify=self.verify, timeout=self.timeout, headers=headers,
        )
        response.raise_for_status()
        return json.loads(response.content.decode('utf-8'))


class ClickhouseClient(object):
    """
    Minimal ClickHouse client over the HTTP interface.

    Every query is sent as the body of a POST to ``self.url``. Query parameters are
    substituted with ``str.format`` after validation (``query_params_check``) and,
    for most keys, escaping (``escape_string``).
    """

    # Parameter keys substituted into queries without escape_string; presumably their
    # values are pre-built SQL fragments or id lists — TODO confirm with callers.
    UNESCAPED_PARAMS = frozenset((
        'metric_ids', 'job_ids', 'cases_with_tag', 'tags', 'targets', 'query_addon', 'query_time',
    ))

    def __init__(self, url=None, connection_timeout=1500,
                 readonly=False):
        """
        :param url: explicit server URL; defaults to http://localhost:8123 for the
            'localhost'/'unittest' environments and https://CLICKHOUSE_HOST:CLICKHOUSE_PORT otherwise
        :param connection_timeout: stored on the instance; NOTE(review): not used by
            ``_post``, which always passes ``timeout=10``
        :param readonly: when True, ``execute`` accepts only SELECT statements, ``insert``
            is refused and ``select``/``select_tsv`` never append settings to the URL
        """
        self.local = settings.ENV_TYPE in ('localhost', 'unittest')
        self.connection_timeout = connection_timeout
        if url is not None:
            self.url = url
        elif self.local:
            self.url = 'http://localhost:8123'
        else:
            self.url = 'https://{}:{}'.format(settings.CLICKHOUSE_HOST, settings.CLICKHOUSE_PORT)
        self.readonly = readonly
        self.session = requests.Session()
        self.session.verify = settings.CERT_FILE
        if not self.local:
            self.session.headers.update({
                'X-ClickHouse-User': settings.CLICKHOUSE_USER,
                'X-ClickHouse-Key': settings.CLICKHOUSE_PWD
            })
        if settings.ENV_TYPE == 'production':
            # 'https://host:port' -> 'host'
            self.session.headers.update({'Host': self.url.split(':')[1].lstrip('//')})

    def execute(self, query, query_params=None):
        """
        Run an arbitrary statement and discard the response body.

        :raises ClickhouseException: if the client is readonly and the statement does not
            start with 'select', or if the server answers with a non-200 status
        """
        if self.readonly and not query.lstrip()[:10].lower().startswith('select'):
            raise ClickhouseException('Clickhouse client is in readonly state')
        query = self._prepare_query(query, query_params=query_params)
        self._raise_for_status(self._post(query.encode('utf-8')), query)

    def insert(self, query: str, query_params=None):
        """
        Run an INSERT statement.

        :raises ClickhouseException: if the client is readonly or the server answers non-200
        :raises AssertionError: if the prepared query does not start with 'insert'
        """
        if self.readonly:
            raise ClickhouseException('Clickhouse client is in readonly state')
        query = self._prepare_query(query, query_params=query_params)
        assert query.lstrip()[:10].lower().startswith('insert')
        self._raise_for_status(self._post(query.encode('utf-8')), query)

    def select(self, query: str, query_params=None, deserialize=True):
        """
        Run a SELECT with ``FORMAT JSONCompact`` appended.

        :param deserialize: when True return the parsed ``data`` list, otherwise the raw body text
        :raises ClickhouseException: on a non-200 response
        :raises AssertionError: if the prepared query does not start with 'select'
        """
        query = self._prepare_query(query, query_params=query_params)
        query = query.rstrip().rstrip(';') + ' FORMAT JSONCompact'
        self._apply_select_settings()
        assert query.lstrip()[:10].lower().startswith('select')
        # Send UTF-8 bytes like the other methods; a str body's encoding would be left to
        # the HTTP stack (http.client encodes str bodies as ISO-8859-1 and rejects other text).
        response = self._post(query.encode('utf-8'))
        self._raise_for_status(response, query)
        body = response.content.decode('utf-8')
        return json.loads(body)['data'] if deserialize else body

    def select_tsv(self, query: str, query_params=None) -> str:
        """
        Run a query and return the raw response body as text (server-default format).

        :raises ClickhouseException: on a non-200 response
        """
        query = self._prepare_query(query, query_params=query_params)
        query = query.rstrip().rstrip(';')
        self._apply_select_settings()
        response = self._post(query.encode('utf-8'))
        self._raise_for_status(response, query)
        return response.content.decode('utf-8')

    def _apply_select_settings(self):
        """
        Append ClickHouse settings to ``self.url`` once; skipped for readonly clients.

        The URL is changed in place, so every later request from this client — including
        ``execute`` and ``insert`` — carries these settings.
        """
        if self.readonly or self.url.endswith('max_threads=3'):
            return
        extra = 'max_threads=3' if self.local else 'joined_subquery_requires_alias=0&max_threads=3'
        # Use '&' when the URL already carries a query string (e.g. one passed via ``url``).
        self.url += ('&' if '?' in self.url else '?') + extra

    @staticmethod
    def _raise_for_status(response, query):
        """Raise ClickhouseException with status, body and query unless the response is HTTP 200."""
        if response.status_code != 200:
            raise ClickhouseException('{code}\n{content}\n=====\n{query}\n===='.format(
                code=response.status_code,
                content=response.content.decode('utf-8').rstrip('\n'),
                query=query))

    def _prepare_query(self, query: str, query_params=None) -> str:
        """
        Validate ``query_params`` and substitute them into ``query`` with ``str.format``.

        Values of keys not in ``UNESCAPED_PARAMS`` are converted to str and passed through
        ``escape_string``. The caller's dict is left unmodified, so it can be reused
        without its values being escaped twice.
        """
        self.query_params_check(query_params)
        if not query_params:
            return query
        prepared = {
            key: value if key in self.UNESCAPED_PARAMS else escape_string(str(value))
            for key, value in query_params.items()
        }
        return query.format(**prepared)

    def _post(self, payload: bytes) -> 'requests.Response':
        return self.session.post(self.url, data=payload, timeout=10)

    @staticmethod
    def query_params_check(query_params: dict) -> None:
        """
        Reject suspicious query parameters.

        'job', 'job_date' and 'compress_ratio' must be digit strings (after ``str()``);
        any other value must not contain the substrings 'select' or 'remote'.
        NOTE(review): that substring check is case-sensitive ('SELECT' passes).

        Explicit raises are used instead of ``assert`` so the checks still run under
        ``python -O``.

        :raises SanitaryError: on an invalid value, or if query_params is neither None nor a dict
        """
        if query_params is None:
            return
        if not isinstance(query_params, dict):
            raise SanitaryError('query_params must be of dict type')
        digit_only = ('job', 'job_date', 'compress_ratio')
        for key, param in query_params.items():
            text = str(param)
            if key in digit_only:
                valid = text.isdigit()
            else:
                valid = 'select' not in text and 'remote' not in text
            if not valid:
                raise SanitaryError('param "{}" is invalid: {}'.format(key, param))


class Singleton(type):
    """
    Metaclass that makes each class using it construct at most one instance.

    The first call creates the instance with the given arguments; later calls return
    that same object and ignore their arguments. Instances are kept in the shared
    ``instances`` dict, keyed by class.
    """
    instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls.instances
        if cls in registry:
            return registry[cls]
        created = super(Singleton, cls).__call__(*args, **kwargs)
        registry[cls] = created
        return created


class MemCache(memcache.Client):
    """
    memcache client that serializes values on write and deserializes them on read.

    ``fmt`` selects the codec: ``'json'`` (json.dumps / json.loads) or ``'b64'``
    (b64encode / b64decode). With any other value, values are stored unchanged, and on
    read bytes are only decoded to str.
    Keys are always converted with ``str()``.
    """

    def __init__(self, servers, fmt='json', **kwargs):
        super().__init__(servers, **kwargs)
        self.fmt = fmt

    def set(self, key, val, **kwargs):
        """
        Encode ``val`` according to ``self.fmt`` and store it under ``str(key)``.

        Extra keyword arguments are accepted but NOT forwarded to the underlying client
        (NOTE(review): e.g. an expiry passed here is silently ignored).
        A TypeError raised while encoding or storing is logged and swallowed; in that
        case None is returned.
        """
        fmt = self.fmt
        try:
            if fmt == 'b64':
                val = b64encode(val)
            elif fmt == 'json':
                val = json.dumps(val)
            return super().set(str(key), val)
        except TypeError:
            logging.error('Failed to set %s value %s to memcache', type(val), val, exc_info=True)

    def get(self, key):
        """
        Fetch ``str(key)`` and decode it according to ``self.fmt``.

        Returns None when the key is missing or when decoding fails (failures are logged:
        TypeError/AttributeError at error level, ValueError — e.g. bad JSON — at debug level).
        """
        value = super().get(str(key))
        if value is None:
            return None
        try:
            if isinstance(value, bytes):
                value = value.decode('utf-8')
            fmt = self.fmt
            if fmt == 'b64':
                value = b64decode(value)
            elif fmt == 'json':
                value = json.loads(value)
            return value
        except TypeError:
            logging.error('Failed to get %s value %s from memcache', type(value), value, exc_info=True)
        except AttributeError:
            logging.error('Failed to decode  %s value %s from memcache', type(value), value, exc_info=True)
        except ValueError:
            logging.debug('Failed to get value from %s: ', value, exc_info=True)

    def delete(self, key, **kwargs):
        """Delete ``str(key)``; extra keyword arguments are ignored and nothing is returned."""
        super().delete(str(key))


class CacheClient(object):
    """
    Factory for the memcache-backed cache.

    ``CacheClient(...)`` returns a ``MemCache`` instance, not a ``CacheClient``.
    """

    def __new__(cls, expire=2592000, fmt='json'):  # a month
        """
        :param expire: NOTE(review): accepted (default 2592000 s = 30 days) but never
            used — it is not passed on to MemCache
        :param fmt: value codec handed to MemCache ('json' or 'b64')
        :return: CaaS client (a MemCache instance)
        """
        # TODO: the server address is hard-coded; settings.MEMCACHE_HOST is not consulted,
        # only the port comes from settings.
        memcache_host = '2a02:6b8:0:1416:225:90ff:feef:c620'
        server = 'inet6:[{}]:{}'.format(memcache_host, settings.MEMCACHE_PORT)
        return MemCache([server], fmt=fmt)


def init_redis():
    """
    Build the Redis clients used for caching.

    If REDIS_HOST is a comma-separated list of several hosts, they are used as Sentinel
    nodes (on REDIS_SENTINEL_PORT), and master/slave clients for the
    'lunapark_cache_redis' service are returned. Otherwise a single StrictRedis client
    is returned in both positions.

    NOTE(review): socket_timeout=0.1 s is passed to Sentinel; depending on the redis-py
    version it may also apply to the master/slave connections — confirm.

    :return: (master, slave) tuple
    """
    hosts = REDIS_HOST.split(',')
    if len(hosts) <= 1:
        client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NAME, password=REDIS_DB_PWD)
        return client, client

    sentinel = Sentinel(
        [(node, REDIS_SENTINEL_PORT) for node in hosts],
        socket_timeout=0.1,
        db=REDIS_DB_NAME,
    )
    service = 'lunapark_cache_redis'
    return (
        sentinel.master_for(service, password=REDIS_DB_PWD),
        sentinel.slave_for(service, password=REDIS_DB_PWD),
    )


class StartrekClient(object, metaclass=Singleton):
    """
    Wrapper around the Startrek API.

    Because of the Singleton metaclass, every ``StartrekClient()`` call after the first
    returns the same instance. It combines the ``startrek_client`` library
    (``self.client``) with a raw ``requests`` session (used by ``create_link`` and ``get``).
    """
    # Applied with re.match, so only the start of a task id is checked against it.
    TASK_ID_REGEX = r'[a-zA-Z]+-\d+'

    def __init__(self):
        self.base_url = settings.STARTREK_URL
        self.token = settings.OAUTH_TOKEN
        self.client = Startrek(token=self.token, useragent='Lunapark', base_url=self.base_url)
        self.cache = CacheClient()

        self.session = requests.Session()
        # NOTE(review): TLS certificate verification is disabled for the raw session, whereas
        # the other clients in this module verify against settings.CERT_FILE — confirm intent.
        self.session.verify = False
        self.session.headers['Authorization'] = 'OAuth {}'.format(self.token)

        # Stand-in task record type (key, resolution, summary); not used by this class's methods.
        self.PseudoTask = namedtuple('PseudoTask', ['key', 'resolution', 'summary'])

    def check_task_exists(self, task_id):
        """
        Check that a task exists on Startrek.

        Always True outside the 'production' environment. Positive answers are cached
        under 'task_<id>_exists'. A Forbidden response is treated as "exists".

        :param task_id: task key, e.g. 'ABC-123'
        :return: bool (or the cached truthy value)
        :raises TaskError: if task_id doesn't match TASK_ID_REGEX or the task is not found
        """
        if settings.ENV_TYPE != 'production':
            return True
        cache_key = 'task_{}_exists'.format(task_id)
        in_cache = self.cache.get(cache_key)
        if in_cache:
            return in_cache
        # Cache miss (None) or a cached falsy value: ask Startrek.
        if not re.match(self.TASK_ID_REGEX, task_id):
            raise TaskError('Task id {} doesn\'t match the pattern \\{}\\'.format(task_id, self.TASK_ID_REGEX))
        try:
            exists = self.client.issues[task_id] is not None
            self.cache.set(cache_key, True)
            return exists
        except NotFound:
            raise TaskError('Task {} not found on StarTrek'.format(task_id))
        except Forbidden:
            self.cache.set(cache_key, True)
            return True

    def create_link(self, task_id, link, relation='"relates"'):
        """
        Attach an external link to a task via a LINK request to /v2/issues/<task_id>.

        :param relation: value of the Link header's ``rel`` parameter (already quoted)
        :return: response body on success; None on connection/timeout/HTTP errors (logged)
        """
        ENDPOINT = '/v2/issues/'
        url = urljoin(urljoin(self.base_url, ENDPOINT), task_id)
        r = requests.Request(method='LINK', url=url, headers={
            'Link': '<{}>; rel={}'.format(link, relation),
            'Authorization': 'OAuth {}'.format(self.token)
        })

        try:
            # prepare_request merges the session defaults into the request.
            resp = self.session.send(self.session.prepare_request(r))
            resp.raise_for_status()
            return resp.content
        except (requests.ConnectionError, requests.Timeout):
            logging.error('StarTrek issue linking unsuccessful', exc_info=True)
        except requests.HTTPError:
            # Only raise_for_status() raises HTTPError here, so resp is bound.
            logging.error('Startrek link create: {}'.format(resp.status_code), exc_info=True)

    def get(self, query):
        """
        Get Anything that Api provides
        :param query: custom query like 'issues?filter=queue:STARTREK&order=-updated'
        :return: response body (json string hopefully); None on connection/timeout/HTTP
            errors, which are logged
        """
        ENDPOINT = '/v2/'
        url = urljoin(urljoin(self.base_url, ENDPOINT), query)
        # Session.send() does not add session headers; prepare_request merges them in
        # (notably Authorization), which a bare Request.prepare() would leave out.
        prepared = self.session.prepare_request(requests.Request(method='GET', url=url))
        try:
            resp = self.session.send(prepared)
        except (requests.ConnectionError, requests.Timeout) as e:
            logging.error('StarTrek issue: %s', e.args)
            return None
        try:
            resp.raise_for_status()
        except requests.HTTPError:
            logging.error('Startrek: {}'.format(resp.content))
            return None
        return resp.content

    def get_task_name(self, task_id):
        """Return the ``summary`` of the task (library exceptions propagate)."""
        return self.client.issues[task_id].summary

    def get_task(self, task_id):
        """
        Fetch a task object via the startrek_client library.

        :raises TaskError: if task_id doesn't match TASK_ID_REGEX (checked explicitly, so
            the validation is not stripped by ``python -O`` as an ``assert`` would be)
        """
        if not re.match(self.TASK_ID_REGEX, task_id):
            raise TaskError('Task id {} doesn\'t match the pattern \\{}\\'.format(task_id, self.TASK_ID_REGEX))
        return self.client.issues[task_id]

    def get_tasks(self, keys):
        """
        Find several tasks at once.

        :param keys: array of task keys
        :return: the library's search result, or None if the lookup raised (error is logged)
        """
        try:
            return self.client.issues.find(keys=keys)
        except Exception:
            logging.error('StarTrek issue: ', exc_info=True)


def link_and_create_task(_, task_key):
    """
    Get or create the Task record for ``task_key`` and link the Startrek task to Lunapark.

    :param _: HTTP request (unused)
    :param task_key: Startrek task key
    :return: the Task instance (the first match if several records share the key)
    :raises TaskError: if the task does not exist on Startrek
    """
    # TODO: check on empty task
    # Imported locally rather than at module level — presumably to avoid an import cycle; TODO confirm.
    from common.models import Task

    startrek = StartrekClient()
    if not startrek.check_task_exists(task_key):
        raise TaskError('Task {} not found on StarTrek'.format(task_key))

    try:
        task, _created = Task.objects.get_or_create(key=task_key)
    except Task.MultipleObjectsReturned:
        task = Task.objects.filter(key=task_key)[0]

    # Linking is best-effort: any failure is logged and the task is still returned.
    try:
        # TODO: hardcoded link
        startrek.create_link(task_key, 'https://lunapark.yandex-team.ru/' + task_key)
    except Exception:
        logging.error('Failed to link or create task', exc_info=True)
    return task


# Base path of the task-status icon images (presumably joined with the names in st_statuses).
st_statuses_img_path = '/media/img/task_statuses/'

# Status key -> icon file name (icons: https://www.flaticon.com/packs/clipboards).
# Each icon is named status_<lower-cased status key>.svg; the empty key maps to the generic icon.
st_statuses = {'': 'status_generic.svg'}
st_statuses.update(
    (status, 'status_{}.svg'.format(status.lower()))
    for status in ('open', 'readyToDeploy', 'needInfo', 'inProgress', 'resolved', 'closed')
)


class MDSClient(object):
    """
    Puts ammo into storage (MDS) and fetches stored files.

    The namespace selects the key sent as the Authorization header:
    'load-ammo' -> settings.STORAGE_AMMO_KEY, 'load-artefact' -> settings.STORAGE_ARTEFACT_KEY,
    anything else -> an empty string.
    """

    def __init__(self, namespace='load-ammo'):
        self.namespace = namespace
        if self.namespace == 'load-ammo':
            auth = settings.STORAGE_AMMO_KEY
        elif self.namespace == 'load-artefact':
            auth = settings.STORAGE_ARTEFACT_KEY
        else:
            auth = ''
        self.headers = {'Authorization': auth}

    def post(self, path, data, content_length):
        """
        Upload ``data`` to MDS.

        :param path: only last part of uri
        :param data: File object; it is closed (``with data``) once the upload returns
        :param content_length: int, sent as the Content-Length header
        :return: dict with 'success', 'url' (download URL built from the 'key' attribute of
            the XML response root) and 'error' (HTTP status code, exception class name or '')
        """
        # '' or 't' is substituted as the first field of the MDS URL templates.
        env = '' if settings.ENV_TYPE == 'production' else 't'
        try:
            url = settings.MDS_UPLOAD_URL.format(env, self.namespace, path)
            # Content-Length is set explicitly: https://github.com/freakboy3742/pyxero/issues/173
            # Use a per-request copy so the instance headers aren't left holding this
            # upload's length.
            headers = self.headers.copy()
            headers['Content-Length'] = str(content_length)
            with data:
                resp = requests.post(url, headers=headers, data=data, verify=settings.CERT_FILE)
            if resp.status_code == 200:
                storage_url = settings.MDS_GET_URL.format(
                    env, self.namespace, ElementTree.fromstring(resp.content).attrib['key']
                )
                logging.debug(resp)
                logging.debug(resp.content)
                return {'success': True, 'url': storage_url, 'error': ''}
            else:
                logging.warning('MDS client returned http %s: %s', resp.status_code, resp.content)
                return {'success': False, 'url': None, 'error': resp.status_code}
        except Exception as exc:
            logging.error('Problems with MDS client', exc_info=True)
            return {'success': False, 'url': None, 'error': exc.__class__.__name__}

    @staticmethod
    def get(path):
        """
        supposed to be used not for ammo, but for configs in apiv2
        :param path: full mds path like "https://storage-int..."
        :return: dict with 'success', 'content' (response bytes or None) and 'error'
            (exception class name, or '' on success)
        """
        try:
            resp = requests.get(path, verify=settings.CERT_FILE)
            resp.raise_for_status()
            return {'success': True, 'content': resp.content, 'error': ''}
        except Exception as exc:
            return {'success': False, 'content': None, 'error': exc.__class__.__name__}
