# coding=utf-8
from __future__ import unicode_literals, absolute_import

import logging
import socket
import struct
import time
import zlib
from datetime import datetime, timedelta
from lxml import etree
from typing import Type

import six

from retrying import Retrying
from travel.library.python.safexml.safe_xml_parser import safe_xml_fromstring
from travel.library.python.xmlutils.xmlutils import lxml_humanize


# TCP ports of the Sirena environments; SirenaClient picks one through PORTNAMES.
STUDY_PORT = 34323
TEST_PORT = 34322
PRODUCTION_PORT = 34321
# Used by SirenaClient when no host is given.
DEFAULT_HOST = '193.104.87.251'
# The same IPv4 address embedded in the NAT64 well-known prefix 64:ff9b::/96 (RFC 6052).
DEFAULT_HOST_IPv6 = '64:ff9b::193.104.87.251'
# timedelta's first positional argument is days: this is a 300-day window,
# used by SirenaClient.get_normative_schedule as end_date - start_date.
SCHEDULE_PERIOD = timedelta(300)

# Size in bytes of the fixed binary header that precedes every request and answer body.
HEADER_LENGTH = 100
# Socket timeout in seconds (passed to socket.create_connection).
SIRENA_TIMEOUT = 150
# Bits of the first flags byte of the header (offset 46).
# Set when the message body is zlib-compressed; checked on answers too.
IS_MESSAGE_ZIPPED_MASK = 0x04
# Set on every request; presumably advertises that the client accepts
# compressed answers — confirm against the Sirena protocol docs.
CAN_ZIP_ANSWER_MASK = 0x10

# Maps SirenaClient's ``port_name`` argument to the environment's port.
PORTNAMES = dict(
    study=STUDY_PORT,
    test=TEST_PORT,
    production=PRODUCTION_PORT,
)

# Query templates. ``.strip()`` removes the leading newline so that the XML
# declaration is the very first thing in the document, as XML requires.
# Values are interpolated with str.format and are NOT XML-escaped, so callers
# must only pass XML-safe values.
# The *_COST values are only used in request log lines ("Баллы" = points) in
# this module; presumably they mirror Sirena's per-request quota cost.

# ``{}`` receives the reference name, e.g. 'aircompany' or 'vehicle'.
DESCRIBE_COST = 1
DESCRIBE_MESSAGE_TEMPLATE = '''
<?xml version="1.0" encoding="UTF-8"?>
<sirena><query>
    <describe>
        <data>{}</data>
    </describe>
</query></sirena>
'''.strip()

# Dates are filled in as '%d.%m.%y' strings by SirenaClient.get_schedule.
NORMATIVE_SCHEDULE_COST = 500
NORMATIVE_SCHEDULE_MESSAGE_TEMPLATE = '''
<?xml version="1.0" encoding="UTF-8"?>
<sirena><query>
    <get_schedule2>
        <date>{start_date}</date>
        <date2>{end_date}</date2>
        <company>{company_code}</company>
    </get_schedule2>
</query></sirena>
'''.strip()

COMPANY_ROUTES_COST = 0
COMPANY_ROUTES_MESSAGE_TEMPLATE = '''
<?xml version="1.0" encoding="UTF-8"?>
<sirena><query>
    <get_company_routes>
        <company>{company}</company>
        <answer_params>
            <lang>{lang}</lang>
        </answer_params>
    </get_company_routes>
</query></sirena>
'''.strip()

logger = logging.getLogger(__name__)


class SirenaDownloadException(Exception):
    """Base class for every error raised while downloading data from Sirena."""


class SirenaSocketClosedPrematurely(SirenaDownloadException):
    """Sirena closed the connection before the expected number of bytes arrived."""


class SirenaAnswerError(SirenaDownloadException):
    """Error reported by Sirena inside an answer (an ``<error>`` element).

    :param error_text: error text from the answer; ``bytes`` are decoded as
        UTF-8, ``None`` is kept as ``None``
    :param error_code: numeric error code, converted with ``int()`` unless ``None``
        (Sirena answers may carry a missing or non-numeric code)
    :param error_xml: human-readable dump of the answer XML, or ``None``
    """

    def __init__(self, error_text, error_code=None, error_xml=None):
        # Populate ``self.args`` so that the default Exception pickling,
        # which re-creates the object as ``cls(*args)``, keeps working.
        super(SirenaAnswerError, self).__init__(error_text, error_code, error_xml)
        # None is a legitimate value for all three fields (see the defaults);
        # int()/ensure_text() would raise TypeError on it and mask the real error.
        self.error_code = None if error_code is None else int(error_code)
        self.error_text = None if error_text is None else six.ensure_text(error_text)
        self.error_xml = None if error_xml is None else six.ensure_text(error_xml)

    def _format_message(self):
        """Return the text rendering shared by ``__str__`` and ``__unicode__``."""
        return '{} #{}: {}\n{}'.format(
            self.__class__.__name__,
            self.error_code,
            self.error_text,
            self.error_xml,
        )

    def __str__(self):
        return six.ensure_str(self._format_message())

    def __repr__(self):
        return six.ensure_str(
            '{}({!r}, error_code={!r}, error_xml={!r})'.format(
                self.__class__.__name__,
                self.error_text,
                self.error_code,
                self.error_xml,
            )
        )

    def __unicode__(self):
        return six.ensure_text(self._format_message())


class SirenaBrandsNotFoundError(SirenaAnswerError):
    """Sirena answer error with code 33518 (see ``sirena_answer_error_for_code``)."""


def sirena_answer_error_for_code(code):
    # type: (int)->Type[SirenaAnswerError]
    """Return the exception class to raise for a Sirena answer error ``code``.

    Codes without a dedicated subclass (``None`` included) fall back to the
    generic ``SirenaAnswerError``.
    """
    dedicated_errors = {33518: SirenaBrandsNotFoundError}
    if code in dedicated_errors:
        return dedicated_errors[code]
    return SirenaAnswerError


class SirenaClient(object):
    """Sirena client over a TCP socket with zlib compression support.

    Every request is a fixed HEADER_LENGTH-byte binary header followed by the
    (possibly zlib-compressed) XML body; answers are read with the same framing.
    The connection is opened lazily. By default it is closed after each request;
    with ``get(..., close_socket=False)`` it is kept for reuse, in which case call
    :meth:`close` (or use the client as a context manager) when done.
    """

    def __init__(self, host, port_name, client_id):
        # type: (str, str, int)->None
        """
        :param host: sirena host; DEFAULT_HOST is used when empty. See
            DEFAULT_HOST_IPv6 for the IPv6 (NAT64) alternative
        :param port_name: key of the PORTNAMES dict. Use 'production' as default
        :param client_id: id given by a partner; packed into the request header
            as a signed 16-bit integer
        :raises ValueError: if port_name or client_id is empty
        :raises KeyError: if port_name is not a PORTNAMES key
        """
        if not port_name:
            raise ValueError('Please choose sirena port')
        if not client_id:
            raise ValueError('Please specify the client id')

        self._message_id = 0
        self._host = host or DEFAULT_HOST
        self._port = PORTNAMES[port_name]
        self._client_id = client_id
        self._socket = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def get_company_routes(self, company_code, lang):
        """Return the company's routes as ``{departure_code: [arrival_text, ...]}``.

        :param company_code: airline code put into the ``<company>`` element
        :param lang: answer language put into ``<answer_params><lang>``
        """
        response = self.get(
            COMPANY_ROUTES_MESSAGE_TEMPLATE.format(company=company_code, lang=lang),
            cost=COMPANY_ROUTES_COST,
        )
        return self._parse_sirena_routes(response)

    def _parse_sirena_routes(self, response):
        """Parse a ``get_company_routes`` answer into ``{departure_code: [arrival_text, ...]}``.

        :raises SirenaDownloadException: if the answer has no
            ``answer/get_company_routes`` element
        """
        logger.info(response)
        # The answer comes from the network: use the same hardened parser as
        # _check_answer rather than a bare etree.fromstring.
        root_el = safe_xml_fromstring(response)
        routes_el = root_el.find('answer/get_company_routes')
        if routes_el is None:
            raise SirenaDownloadException('Sirena answer has no answer/get_company_routes element.')
        # Repeated departure codes: the last occurrence wins (as before).
        return {
            departure.get('code'): [arrival.text for arrival in departure.findall('arrival')]
            for departure in routes_el.findall('departure')
        }

    def get_airlines_reference(self):
        """Return the raw answer of the ``describe`` query for 'aircompany'."""
        return self.get(DESCRIBE_MESSAGE_TEMPLATE.format('aircompany'), cost=DESCRIBE_COST)

    def get_transport_models_reference(self):
        """Return the raw answer of the ``describe`` query for 'vehicle'."""
        return self.get(DESCRIBE_MESSAGE_TEMPLATE.format('vehicle'), cost=DESCRIBE_COST)

    def get_schedule(self, company_code, start_date, end_date):
        """Return the raw ``get_schedule2`` answer for ``company_code``.

        :param start_date: first day; formatted with ``strftime('%d.%m.%y')``
        :param end_date: last day; formatted the same way
        """
        return self.get(
            NORMATIVE_SCHEDULE_MESSAGE_TEMPLATE.format(
                company_code=company_code,
                start_date=start_date.strftime('%d.%m.%y'),
                end_date=end_date.strftime('%d.%m.%y')
            ),
            cost=NORMATIVE_SCHEDULE_COST
        )

    def get_normative_schedule(self, company_code):
        """Return the schedule from now until now + SCHEDULE_PERIOD."""
        start_date = datetime.now()
        end_date = start_date + SCHEDULE_PERIOD
        return self.get_schedule(company_code, start_date, end_date)

    def next_message_id(self):
        """Return the next per-client message id (1, 2, 3, ...)."""
        self._message_id += 1
        return self._message_id

    def make_full_message(self, message, timestamp, message_id):
        """Return ``header + body`` for a request, compressing the body when that makes it smaller.

        Header layout (network byte order, HEADER_LENGTH bytes in total):
        offset 0 body length (int32), 4 timestamp (int32), 8 message id (int32),
        12-43 zero padding, 44 client id (int16), 46 first flags byte,
        47 second flags byte, 48 key id (int32, always 0), 52-99 zero padding.

        :param message: request body as bytes
        """
        flags_first_byte = CAN_ZIP_ANSWER_MASK
        zipped = zlib.compress(message)
        if len(zipped) < len(message):
            message = zipped
            flags_first_byte |= IS_MESSAGE_ZIPPED_MASK
        flags_second_byte = 0
        key_id = 0
        # ``x`` pad bytes are emitted as NUL, identical to the explicit b'\00' chars.
        header = struct.pack(
            b'!iii32xhbbi48x',
            len(message),
            timestamp,
            message_id,
            self._client_id,
            flags_first_byte,
            flags_second_byte,
            key_id,
        )
        assert len(header) == HEADER_LENGTH, 'Header length is {} instead of {}.'.format(len(header), HEADER_LENGTH)
        return header + message

    def _receive(self, sock, length):
        """Read exactly ``length`` bytes from ``sock``.

        :raises SirenaSocketClosedPrematurely: if the peer closes the connection
            first; args are (message, bytes received so far, partial data)
        """
        chunks = []
        received = 0
        while received < length:
            chunk = sock.recv(length - received)
            if not chunk:
                raise SirenaSocketClosedPrematurely(
                    six.ensure_str('Сирена закрыла сокет до того, как мы считали все данные.'),
                    received,
                    b''.join(chunks),
                )
            chunks.append(chunk)
            received += len(chunk)
        # Joining once avoids the quadratic cost of repeated bytes concatenation.
        return b''.join(chunks)

    def _check_answer(self, answer):
        """Raise the SirenaAnswerError subclass for ``answer`` if it contains an ``<error>`` element."""
        root_el = safe_xml_fromstring(answer)
        error_el = root_el.find('.//error')
        if error_el is None:
            return

        try:
            error_code = int(error_el.get('code', ''))
        except ValueError:
            # Missing or non-numeric ``code`` attribute.
            error_code = None

        raise sirena_answer_error_for_code(error_code)(
            # An empty <error/> element has ``text`` None.
            error_el.text or '',
            error_code=error_code,
            error_xml=lxml_humanize(root_el),
        )

    def close(self):
        """Close the cached connection, if any."""
        self.close_socket(self._socket)

    def close_socket(self, sock):
        """Close ``sock`` (no-op for None) and forget it if it is the cached connection."""
        if not sock:
            return
        sock.close()
        if sock == self._socket:
            self._socket = None

    def _get_socket(self, timeout):
        """Return the cached connection, opening one if needed.

        The timeout is re-applied to a reused socket so that every request
        honours its own ``timeout`` argument.
        """
        if self._socket is None:
            self._socket = socket.create_connection(
                (self._host, self._port),
                timeout=timeout,
            )
        else:
            self._socket.settimeout(timeout)

        return self._socket

    def get(self, request_msg, cost, timeout=SIRENA_TIMEOUT, close_socket=True):
        """Send ``request_msg`` and return the (decompressed) answer bytes.

        Socket errors are retried: at most 5 attempts, 2 seconds apart; the last
        socket error is re-raised. Other errors are not retried.

        :param request_msg: XML query (text or bytes)
        :param cost: request cost; only written to the log
        :param timeout: socket timeout in seconds
        :param close_socket: close the connection after the request (default)
        :raises SirenaAnswerError: if the answer contains an ``<error>`` element
        :raises SirenaDownloadException: on framing problems
        """
        return Retrying(
            stop_max_attempt_number=5,
            wait_fixed=2000,
            # A predicate works with every ``retrying`` release; only newer
            # releases accept a tuple of exception types here.
            retry_on_exception=lambda exc: isinstance(exc, socket.error),
        ).call(self._try_get, request_msg, cost, timeout=timeout, close_socket=close_socket)

    def _try_get(self, request_msg, cost, timeout=SIRENA_TIMEOUT, close_socket=True):
        """Perform a single request/answer round trip without retries (see :meth:`get`)."""
        request_msg = six.ensure_binary(request_msg)
        timestamp = int(time.time())
        message_id = self.next_message_id()
        logger.info('Начинаем запрос %s-%s к Сирене. Баллы: %s', message_id, timestamp, cost)
        logger.info('Тело запроса: %s', request_msg.decode('utf-8'))
        logger.info('Хост: %s:%s', self._host, self._port)
        full_msg = self.make_full_message(request_msg, timestamp=timestamp, message_id=message_id)
        sock = self._get_socket(timeout)
        exchange_complete = False
        try:
            sock.sendall(full_msg)

            header = self._receive(sock, HEADER_LENGTH)

            length = struct.unpack_from(b'!i', header, offset=0)[0]
            data = self._receive(sock, length)
            exchange_complete = True
        finally:
            # Drop the connection on ANY failure mid-exchange (not only
            # socket.error): a half-read answer would desynchronise the next
            # request sent over a kept-alive socket.
            if close_socket or not exchange_complete:
                self.close_socket(sock)

        timestamp_from_answer, message_id_from_answer = struct.unpack_from(b'!ii', header, offset=4)
        if timestamp != timestamp_from_answer or message_id != message_id_from_answer:
            self.close_socket(sock)
            raise SirenaDownloadException('Ответ Сирены не совпадает по timestamp или message_id.')

        flags_first_byte = struct.unpack_from(b'!b', header, offset=46)[0]
        if flags_first_byte & IS_MESSAGE_ZIPPED_MASK:
            data = zlib.decompress(data)

        logger.info('Сирена считает, что запрос сделан. Баллы: %s', cost)

        try:
            self._check_answer(data)
        except SirenaAnswerError:
            self.close_socket(sock)
            raise

        logger.info('Успешно сделали запрос к Сирене. Баллы: %s', cost)
        return data
