# -*- coding: utf-8 -*-
import functools
import gzip
import inspect
import json
import os
import re
import requests
import shutil
from urllib import quote_plus
from urlparse import urljoin

from logging import getLogger
from collections import OrderedDict
from traceback import format_exc

import gevent
from django.conf import settings

from travel.avia.library.python.avia_data.models.amadeus_merchant import AmadeusMerchant
from travel.avia.library.python.common.models.partner import DohopVendor
from travel.avia.library.python.proxy_pool.django import partner_proxy
from travel.avia.ticket_daemon.ticket_daemon.daemon.utils import BadPartnerResponse, TimedChunk
from travel.avia.ticket_daemon.ticket_daemon.lib.feature_flags import save_partners_exception_tracks
from travel.avia.ticket_daemon.ticket_daemon.lib.timer import Timer
from travel.avia.ticket_daemon.ticket_daemon.lib.ipc import lock_resource

# Silence urllib3's InsecureRequestWarning for the whole process (the filter
# is installed once, at import time of this module).
from requests.packages.urllib3.exceptions import InsecureRequestWarning

requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

# Module-level logger; used by QueryTracker.done() to report store errors.
log = getLogger(__name__)


class ExchangeMessage(object):
    """One recorded event of an exchange.

    :param code: event name; it becomes the suffix of the stored file name.
    :param content: payload to store.
    :param compress: store gzip-compressed (``True``, the default) or as a
        plain file (``False``).
    """

    def __init__(self, code, content, compress=True):
        self.code, self.content, self.compress = code, content, compress


def get_partner_code_from_filename(filename):
    """Derive a partner code from the path of a partner module.

    The base name is stripped of its extension and lower-cased, then any
    trailing digits (module version suffixes) are removed, e.g.
    ``/path/Foo2.py`` -> ``foo``.

    NOTE(review): a code that legitimately ends in digits loses them too.
    """
    stem = os.path.splitext(os.path.basename(filename))[0]
    # Drop the trailing module version number, if any.
    return re.sub(r'\d*$', '', stem.lower())


class QueryTracker(object):
    """Record the HTTP exchanges of one partner query and optionally store them.

    A tracker is created per (query, partner) pair by :meth:`init_query`.
    Requests made through :meth:`wrap_request` are recorded as
    :class:`Exchange` objects. When an active tracker's query finishes,
    :meth:`done` writes the recorded events under ``TRACKS_PATH``. If
    ``q.tracker_use_cache`` is set and a matching stored track exists,
    responses are replayed from that track instead of being requested.
    """

    # Root directory of stored tracks: <settings.LOG_PATH>/track.
    TRACKS_PATH = os.path.join(os.path.normpath(settings.LOG_PATH), 'track')
    # Directory the full track was written to by done(); None until then.
    store_place = None

    def __init__(self, partner_code, q):
        """Create a tracker for *q* and decide whether it is active.

        :param partner_code: partner code used in the key and store paths.
        :param q: the query; its optional attributes ``need_store_tracks``,
            ``tracker_use_cache`` and ``trackers`` are consulted.
        """
        self.partner_code = partner_code
        self.q = q
        self.exchanges = []

        # ``active``: done() is called when the query finishes.
        # ``need_store_tracks``: the full track is stored under ``self.key``.
        self.active = self.need_store_tracks = getattr(q, 'need_store_tracks', False)

        # Exception tracks need done() even if full tracks are not requested.
        if save_partners_exception_tracks():
            self.active = True

        use_cache = getattr(q, 'tracker_use_cache', False)
        if use_cache:
            self.cache_place = self.find_cache_place()
            if self.cache_place:
                # Replaying a stored track: do not write a new full track.
                self.need_store_tracks = False
                self.active = True
        else:
            self.cache_place = None
        self.timer = Timer()
        self.query_time = .0

        if hasattr(q, 'trackers'):
            # Register in the caller-provided registry under self.key.
            q.trackers[self.key] = self
            self.active = True

        self.parsing_exception = None

    @property
    def key(self):
        """Tracker key of the form ``'<q.id>#<partner_code>'``."""
        # q.id already contains q.key
        return '%s#%s' % (self.q.id, self.partner_code)

    @classmethod
    def init_query(cls, query_fn):
        """Decorate a partner ``query_fn(tracker, q)`` with tracking.

        The partner code defaults to the name of the module file defining
        *query_fn* (trailing version digits removed). When ``q.importer`` has
        partners, the first one decides it instead: 'dohop' for DohopVendor,
        'amadeus' for AmadeusMerchant, otherwise ``partner.code``.
        The returned wrapper takes only ``q`` and returns a generator of
        TimedChunk objects.
        """
        # __code__ is available on Python 2.6+ and 3 (func_code is 2-only).
        partner_filename = query_fn.__code__.co_filename
        partner_code = get_partner_code_from_filename(partner_filename)

        @functools.wraps(query_fn)
        def query_wrapper(q):
            code = partner_code

            if hasattr(q, 'importer') and q.importer.partners:
                partner = q.importer.partners[0]

                if isinstance(partner, DohopVendor):
                    code = 'dohop'
                elif isinstance(partner, AmadeusMerchant):
                    code = 'amadeus'
                else:
                    code = partner.code

            tracker = cls(code, q)
            return tracker.call_query_fn(query_fn, q)

        return query_wrapper

    def call_query_fn(self, query_fn, q):
        """Run ``query_fn(self, q)`` lazily and yield results as TimedChunk.

        A non-generator result is yielded as a single chunk. Any exception
        marks the tracker with a parsing exception (keeping a more specific
        one set earlier, e.g. by check_response) and is re-raised. Once
        iteration has started, done() is called (if the tracker is active)
        however the generator ends.
        """
        def chunks_wrapper(tracker):
            try:
                query_result = query_fn(self, q)

                if inspect.isgenerator(query_result):
                    chunks = query_result
                else:
                    chunks = iter([query_result])

                for chunk in chunks:
                    yield TimedChunk(chunk, tracker.query_time)
            except Exception:
                self.parsing_exception = self.parsing_exception or 'parsing_exception'
                raise
            finally:
                if tracker.active:
                    tracker.done()

        return chunks_wrapper(self)

    def wrap_request(self, method, *args, **kwargs):
        """Perform a request through a new Exchange and validate the response.

        ``exchange_title`` (popped from kwargs) becomes the exchange title.
        In replay mode *method* is replaced by one returning a 200 response
        whose body is loaded from the cached track. After the call the query
        time is updated, control is yielded to other greenlets, and
        check_response() may raise BadPartnerResponse.
        """
        exchange = self.add_exchange(title=kwargs.pop('exchange_title', ''))

        if self.cache_place:
            def method(*margs, **mkwargs):
                r = requests.Response()
                r.status_code = 200
                r._content = self.load_gzipped(
                    self.cache_place, exchange.key, 'response_content')
                return r

        resp = exchange.wrap_request(method, *args, **kwargs)

        self.set_query_time_from_timer()
        gevent.sleep(0)
        self.check_response(resp)

        return resp

    def set_query_time_from_timer(self):
        """Store the seconds elapsed since the tracker was created."""
        self.query_time = self.timer.get_elapsed_seconds()

    def check_response(self, resp):
        """Raise BadPartnerResponse for a non-OK ``requests.Response``.

        Non-Response values are ignored. On failure the parsing exception is
        set to ``'<status>_status_code'`` before raising.
        """
        if not isinstance(resp, requests.Response):
            return

        if not resp.ok:
            self.parsing_exception = '%d_status_code' % resp.status_code
            raise BadPartnerResponse(self.partner_code, resp)

    def add_exchange(self, title=''):
        """Create, register and return the next Exchange (ids start at 1)."""
        exchange = Exchange(self, id_=len(self.exchanges) + 1, title=title)

        self.exchanges.append(exchange)

        return exchange

    def done(self):
        """Persist recorded exchanges; storage errors are logged, not raised.

        If a parsing exception was recorded, the track is written to
        ``bad/<partner_code>/<parsing_exception>``, replacing the previous
        one for that pair. If full tracks are requested, it is also written
        to ``self.key`` and that directory is remembered in ``store_place``.
        """
        try:
            if self.parsing_exception:
                key = 'bad/%s/%s' % (self.partner_code, self.parsing_exception)

                with lock_resource(key):
                    _exceptions_store_path = self.prepare_store_place(key)
                    self._store(_exceptions_store_path)

            if self.need_store_tracks:
                with lock_resource(self.key):
                    self.store_place = self.prepare_store_place(self.key)
                    self._store(self.store_place)

        except Exception:
            log.exception('Store error')

    def find_cache_place(self):
        """Return a stored track directory matching this query, or None.

        Candidates are entries of TRACKS_PATH whose name contains the part of
        ``self.key`` after its first '.'; the lexicographically last one is
        chosen (presumably the most recent — confirm the naming scheme).
        """
        if not os.path.exists(self.TRACKS_PATH):
            return
        key_part = self.key.split('.', 1)[-1]
        candidates = sorted([
            n for n in os.listdir(self.TRACKS_PATH) if key_part in n
        ])
        if not candidates:
            return
        return os.path.join(self.TRACKS_PATH, candidates[-1])

    def _store(self, store_path):
        """Write every event of every exchange as ``<exchange key>_<code>``."""
        for exchange in self.exchanges:
            for exchange_event in exchange.events:
                title = '{}_{}'.format(exchange.key, exchange_event.code)

                if exchange_event.compress:
                    self.save_gzipped(store_path, title, exchange_event.content)

                else:
                    self.save_plain(store_path, title, exchange_event.content)

    @staticmethod
    def load_gzipped(path, ekey, event_code):
        """Read a stored gzip event ``<path>/<ekey>_<event_code>.txt.gz``."""
        title = '{}_{}'.format(ekey, event_code)
        filename = os.path.join(path, title + '.txt.gz')
        with gzip.open(filename, 'rb') as f:
            return f.read()

    @staticmethod
    def _as_bytes(content):
        """Return *content* as bytes, UTF-8 encoding text when needed.

        Binary-mode files reject text (Python 3) or implicitly encode it as
        ASCII (Python 2), so non-ASCII text such as the unescaped JSON
        produced by dump_response could not otherwise be written.
        """
        if isinstance(content, bytes):
            return content
        return content.encode('utf-8')

    @staticmethod
    def save_gzipped(store_path, title, content):
        """Write *content* gzip-compressed to ``<store_path>/<title>.txt.gz``."""
        filename = os.path.join(store_path, title + '.txt.gz')

        with gzip.open(filename, 'wb') as f:
            f.write(QueryTracker._as_bytes(content))

    @staticmethod
    def save_plain(store_path, title, content):
        """Write *content* uncompressed to ``<store_path>/<title>.txt``."""
        filename = os.path.join(store_path, title + '.txt')

        with open(filename, 'wb') as f:
            f.write(QueryTracker._as_bytes(content))

    @staticmethod
    def _ensure_dir(path):
        """Create *path* (with parents) unless it already is a directory."""
        try:
            os.makedirs(path)
        except OSError:
            # Another process may have created it between check and create.
            if not os.path.isdir(path):
                raise

    @staticmethod
    def prepare_store_place(key):
        """Return a freshly created directory ``<TRACKS_PATH>/<key>``.

        Any existing directory for *key* is removed first, so only the most
        recent track per key is kept.
        """
        QueryTracker._ensure_dir(QueryTracker.TRACKS_PATH)

        store_place = os.path.join(QueryTracker.TRACKS_PATH, key)

        if os.path.exists(store_place):
            shutil.rmtree(store_place, ignore_errors=True)

        os.makedirs(store_place)

        return store_place


class Exchange(object):
    """A single request/response round-trip recorded by a QueryTracker."""

    def __init__(self, tracker, id_, title):
        self.tracker = tracker
        self.id = id_
        self.title = title
        # ExchangeMessage records collected for this round-trip.
        self.events = []

    @property
    def key(self):
        """Sequence id followed by the (possibly empty) title, e.g. '2search'."""
        return '{}{}'.format(self.id, self.title)

    def as_header(self):
        """Return the key ASCII-encoded ('?' for others) and URL-quoted."""
        return quote_plus(self.key.encode('ASCII', 'replace'))

    @staticmethod
    def _merge_headers(headers, extra):
        """Return a new dict of *headers* overlaid with *extra*.

        The caller's mapping is copied, not updated in place, so a shared
        headers dict does not accumulate X-* headers across requests;
        ``None`` is treated as no headers.
        """
        merged = dict(headers or {})
        merged.update(extra)
        return merged

    def wrap_request(self, method, *args, **kwargs):
        """Perform a request through the partner proxy and record it.

        :param method: must come from the ``requests`` library; it may also be
            a method of a ``requests`` Session instance.
        :param args: positional request arguments (the first is the URL).
        :param kwargs: request keyword arguments. ``timeout`` is always set to
            ``settings.PARTNER_QUERY_TIMEOUT``; ``X-Query`` and ``X-Exchange``
            are added to a copy of ``headers``.
        :returns: whatever ``partner_proxy.request`` returns.
        :raises: re-raises any exception after recording the request
            parameters and the traceback as events.
        """
        if settings.STRESS_PARTNERS_EMULATOR_URL:
            # Redirect every request to the partners emulator.
            emulator_url = urljoin(
                settings.STRESS_PARTNERS_EMULATOR_URL,
                self.tracker.partner_code
            )
            if args:
                args = list(args)
                args[0] = emulator_url
            elif 'url' in kwargs:
                kwargs['url'] = emulator_url

        kwargs['timeout'] = settings.PARTNER_QUERY_TIMEOUT

        xheaders = {
            'X-Query': self.tracker.q.as_header(),
            'X-Exchange': self.as_header(),
        }
        kwargs['headers'] = self._merge_headers(kwargs.get('headers'), xheaders)

        request_timer = Timer()

        try:
            r = partner_proxy.request(method, *args, **kwargs)

            # Replayed (cached) responses are not re-recorded.
            if self.tracker.active and not self.tracker.cache_place:
                self._add_tracks(r, request_timer, args)

            return r

        except Exception:
            request_params = json.dumps(
                [args, kwargs], indent=4, separators=(',', ': '), default=repr
            )

            self.events.append(
                ExchangeMessage('request_params', request_params)
            )
            self.events.append(ExchangeMessage('error', format_exc()))

            raise

    def _add_tracks(self, response, request_timer, request_args):
        """Record timing, a JSON dump and the raw body of *response*."""
        self.events.append(ExchangeMessage(
            'elapsed', str(request_timer.get_elapsed_seconds()),
            compress=False
        ))

        if isinstance(response, requests.Response):
            self.events.append(ExchangeMessage(
                'elapsed-r', str(response.elapsed.total_seconds()),
                compress=False
            ))

        self.events.append(
            ExchangeMessage('request_response', dump_response(response, *request_args)))

        if hasattr(response, 'content'):
            self.events.append(ExchangeMessage('response_content', response.content))

    def __str__(self):
        return self.key


def dump_response(r, *args):
    """Serialize request arguments and response details as pretty JSON.

    :param r: the value returned by the request call.
    :param args: positional request arguments; the first is stored as
        ``url`` and any remaining ones as ``args``.
    :returns: indented JSON text with non-ASCII characters left unescaped.
    """
    params = OrderedDict()

    if args:
        url, rest = args[0], args[1:]
        params['url'] = url
        if rest:
            params['args'] = rest

    params.update(requests_response_params(r))

    return json.dumps(params, indent=4, separators=(',', ': '), ensure_ascii=False)


def _response_attr(obj, name):
    """Return ``obj.<name>``, or None if it is missing or raises.

    Mirrors Python 2 ``hasattr`` semantics (any Exception raised by a property
    counts as absent) and additionally treats a None value as absent.
    """
    try:
        return getattr(obj, name)
    except Exception:
        return None


def requests_response_params(r):
    """Collect details of a ``requests.Response`` and its request as a dict.

    Returns ``{}`` for anything that is not a ``requests.Response``.
    Otherwise returns an OrderedDict with the non-None attributes of
    ``r.request`` among auth/body/data/full_url/method/params/url, the
    request ``headers``, and a ``response_details`` dict holding the response
    headers, string-escaped reason and content, status_code, ok and encoding.
    Attributes that are None (``requests.Response`` defaults ``reason`` and
    ``request`` to None) are skipped instead of crashing.
    """
    if not isinstance(r, requests.Response):
        return {}

    params = OrderedDict()
    response_details = OrderedDict()

    headers = _response_attr(r, 'headers')
    if headers is not None:
        response_details['headers'] = dict(headers)

    reason = _response_attr(r, 'reason')
    if reason is not None:
        response_details['reason'] = reason.encode('string-escape')

    content = _response_attr(r, 'content')
    if content is not None:
        response_details['content'] = content.encode('string-escape')

    for key in ('status_code', 'ok', 'encoding'):
        value = getattr(r, key, None)

        if value is not None:
            response_details[key] = value

    request = getattr(r, 'request', None)
    if request is not None:
        for key in ('auth', 'body', 'data', 'full_url', 'method', 'params', 'url'):
            value = getattr(request, key, None)

            if value is not None:
                params[key] = value

        params['headers'] = dict(request.headers or {})

    if response_details:
        params['response_details'] = response_details

    return params
