# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

import requests
import six
import io
from flask import send_file
from itertools import chain, product

from cachetools.func import ttl_cache
from zeep.exceptions import TransportError

from yabus import common
from yabus.common.exceptions import OfflinePaymentDisabled, PartnerError, PointNotFound, InvalidTicket
from yabus.common.entities import RawSegments
from yabus.etraffic.baseclient import BaseClient
from yabus.etraffic.converter import point_converter
from yabus.etraffic.defaults import TICKET_PDF, RPS_LIMIT, SUPPLIER_CODE
from yabus.etraffic.entities import (
    AtpRideDetails,
    Book,
    Endpoint,
    Order,
    Refund,
    RefundInfo,
    Ride,
    RideDetails,
)
from yabus.etraffic.exceptions import parse_error
from yabus.etraffic.segments_provider import segments_provider
from yabus.util import decode_utf8
from yabus.util.bottleneck import bottleneck
from yabus.util.formatting import python
from yabus.util.parallelism import pmap, ThreadPool
from yabus.util.retry import retry

# Module logger; `Client.search` logs failed per-direction partner calls here.
logger = logging.getLogger(__name__)


class Client(common.Client, BaseClient):
    """Client for the etraffic bus partner.

    Each public method wraps one or more partner calls made through
    ``self.call`` (inherited; presumably from ``BaseClient``) and converts the
    responses into ``yabus.etraffic.entities`` objects.
    """

    # Point converter used by `search` to map our point uids to supplier ids.
    converter = point_converter

    # Seconds to wait for the partner when downloading a ticket PDF. Without a
    # timeout `requests.get` may block the worker forever on a stalled server.
    ticket_pdf_timeout = 60

    def endpoints(self):
        """Return every supplier point as an ``Endpoint`` entity.

        Depots and segments are fetched in parallel; the points that appear in
        the segments are merged with depot data by ``_make_endpoints``.
        """
        depots, segments = pmap(lambda f: f(), [
            self._get_depots,
            self._get_segments
        ])
        endpoints = self._make_endpoints(chain.from_iterable(segments), depots)
        return [
            Endpoint.init(decode_utf8(python(x)))
            for x in endpoints
        ]

    def segments(self):
        """Return supplier segments mapped through ``point_converter``.

        Also refreshes the ``segments_provider`` cache (see ``_raw_segments``).
        """
        return list(point_converter.gen_map_segments(self._raw_segments()))

    def raw_segments(self):
        """Return unmapped (dispatch_id, arrival_id) pairs as ``RawSegments``.

        Also refreshes the ``segments_provider`` cache (see ``_raw_segments``).
        """
        return RawSegments.init({
            'segments': self._raw_segments(),
        })

    def search(self, from_uid, to_uid, date, _=False):
        """Search rides from ``from_uid`` to ``to_uid`` on ``date``.

        Each uid is mapped to a set of supplier ids, and ``getRaces`` is
        queried for every (from, to) pair in parallel. When
        ``segments_provider`` holds cached segments, pairs absent from it are
        skipped. Rides are filtered with ``common.respfilters.coherence``.

        :param _: unused by this client.
        :return: list of rides; ``[]`` if either point is unknown. A partner
            error for a single pair is logged and contributes no rides.
        """
        def psearch(direction):
            # Query one (from_sid, to_sid) pair; partner errors yield no rides.
            from_sid, to_sid = direction
            payload = [
                from_sid,
                to_sid,
                date,
            ]
            try:
                response = python(self.call('getRaces', payload))
            except PartnerError as e:
                logger.exception('cannot get rides from_sid=%s to_sid=%s date=%s: %s', from_sid, to_sid, date, e)
                return []
            rides = Ride.init(response)
            return common.respfilters.coherence(SUPPLIER_CODE, from_uid, to_uid, from_sid, to_sid, date, rides)

        try:
            # map with use_relations returns direct map of the point or map of its children if the point is not mapped
            # map_to_children returns map of children only if the point is mapped
            from_ids = self.converter.map(from_uid, use_relations=True) | self.converter.map_to_children(from_uid)
            to_ids = self.converter.map(to_uid, use_relations=True) | self.converter.map_to_children(to_uid)
        except PointNotFound:
            return []

        directions = product(from_ids, to_ids)
        cached_segments = segments_provider.get_segments()
        if cached_segments:
            # Skip pairs the partner is known not to serve.
            directions = (direction for direction in directions if direction in cached_segments)

        return list(chain.from_iterable(pmap(psearch, directions)))

    @parse_error
    def ride(self, ride_id):
        """Fetch a single ride by ``ride_id['ride_sid']`` via ``getRace``."""
        response = python(self.call('getRace', [ride_id['ride_sid']]))
        return Ride.init(response)

    @parse_error
    def ride_details(self, ride_id):
        """Fetch ride details via ``getRaceSummary``.

        Uses ``AtpRideDetails`` when ``ride_id`` has a truthy ``'atp'`` entry,
        otherwise ``RideDetails``.
        """
        ride_details_cls = AtpRideDetails if ride_id.get('atp') else RideDetails
        response = python(self.call('getRaceSummary', [ride_id['ride_sid']]))
        return ride_details_cls.init(response, citizenships=self.countries)

    @parse_error
    def book(self, ride_id, passengers, pay_offline):
        """Book seats on a ride for ``passengers`` via ``bookOrder``.

        :raises OfflinePaymentDisabled: if ``pay_offline`` is truthy.
        """
        if pay_offline:
            raise OfflinePaymentDisabled
        passengers = [Book.init(x) for x in passengers]
        response = python(self.call('bookOrder', [ride_id['ride_sid'], passengers]))
        return Order.init(response, tickets_fwd={
            '__countries__': self.countries,
            '__order_sid__': response['id'],
        })

    @parse_error
    def confirm(self, order_id):
        """Confirm an order via ``confirmOrder`` and return its state.

        If the order is already in status ``'S'`` (sold), ``confirmOrder`` is
        not called again and the current order state is returned.
        """
        response = self._order_info(order_id['order_sid'])
        if response['status'] != 'S':  # S is for sold
            response = python(self.call('confirmOrder', [order_id['order_sid']]))
        return Order.init(response, tickets_fwd={
            '__countries__': self.countries,
            '__order_sid__': response['id'],
            '__gen_url__': True,
        })

    @parse_error
    def order(self, order_id):
        """Return the current state of an order (``getOrder``)."""
        response = self._order_info(order_id['order_sid'])
        return Order.init(response, tickets_fwd={
            '__countries__': self.countries,
            '__order_sid__': response['id'],
            '__gen_url__': True,
        })

    def ticket_blank(self, ticket_id):
        """Download a ticket's PDF blank and return it as a file attachment.

        :param ticket_id: dict with ``ticket_sid`` plus either ``order_sid``
            (the ticket hash is looked up in that order) or, in the legacy
            format, ``ticket_hash``.
        :return: Flask ``send_file`` response named ``ticket_<ticket_sid>.pdf``.
        :raises InvalidTicket: if the ticket hash cannot be determined.
        :raises requests.HTTPError: if the download returns an error status.
        :raises requests.Timeout: if the partner does not respond within
            ``ticket_pdf_timeout`` seconds.
        """
        ticket_sid = ticket_id['ticket_sid']
        if 'order_sid' in ticket_id:
            ticket_hash = self._find_ticket_hash(ticket_id['order_sid'], ticket_sid)
        else:
            # TODO: remove fallback to old @id in future
            ticket_hash = ticket_id['ticket_hash']
        if ticket_hash is None:
            raise InvalidTicket("ticket \"{}\" not found in order".format(ticket_id))

        ticket_resp = requests.get(
            TICKET_PDF.format(host=self.host, version=self.version, hash=ticket_hash),
            timeout=self.ticket_pdf_timeout,
        )
        ticket_resp.raise_for_status()
        return send_file(
            io.BytesIO(ticket_resp.content),
            mimetype='application/octet-stream',
            as_attachment=True,
            attachment_filename='ticket_{}.pdf'.format(ticket_sid)
        )

    def _find_ticket_hash(self, order_sid, ticket_sid):
        """Return the ``hash`` of ticket ``ticket_sid`` in order ``order_sid``, or ``None``."""
        for ticket in self._order_info(order_sid)['tickets']:
            if ticket['id'] == ticket_sid:
                return ticket['hash']
        return None

    @parse_error
    def refund_info(self, ticket_id):
        """Return an empty ``RefundInfo``; no partner call is made."""
        return RefundInfo.init({})

    @parse_error
    def refund(self, ticket_id):
        """Refund a ticket via ``returnTicket``."""
        response = python(self.call('returnTicket', [ticket_id['ticket_sid']]))
        return Refund.init(response)

    def _order_info(self, order_sid):
        """Return the ``getOrder`` response for ``order_sid`` converted by ``python()``."""
        return python(self.call('getOrder', [order_sid]))

    def change_ride_endpoints(self, ride_id, pickup_sid, discharge_sid):
        """Return a copy of ``ride_id`` with new pickup/discharge ids in ``ride_sid``.

        ``ride_sid`` is a ``':'``-separated string; its second-to-last part is
        replaced by ``pickup_sid`` and its last part by ``discharge_sid``.
        A ``None`` argument leaves the corresponding part unchanged.
        """
        values = ride_id['ride_sid'].split(':')
        for i, code in [(-2, pickup_sid), (-1, discharge_sid)]:
            if code is not None:
                values[i] = code
        # six.text_type (not str) so non-ASCII parts do not break on Python 2.
        return dict(ride_id, ride_sid=':'.join(map(six.text_type, values)))

    def _raw_segments(self):
        """Return (dispatch_id, arrival_id) text pairs and store them in ``segments_provider``.

        ``search`` reads the stored pairs back to skip unserved directions.
        """
        raw_segments = [(six.text_type(x['id']), six.text_type(y['id'])) for x, y in self._get_segments()]
        segments_provider.store_segments(raw_segments)
        return raw_segments

    def _get_segments(self):
        """Crawl the partner's point graph and return (dispatch, arrival) pairs.

        Only countries with code ``'RU'`` are walked:
        countries -> regions -> dispatch points -> arrival points, using a pool
        of 10 threads. Dispatch points carry ``region_info``
        (see ``_get_dispatch_points``).
        """
        countries = [x for x in self._get_countries() if x['code'] == 'RU']
        with ThreadPool(size=10) as pool:
            regions = chain.from_iterable(
                pool.imap_unordered(lambda x: self._get_regions(x['id']), countries)
            )
            dispatches = chain.from_iterable(
                pool.imap_unordered(self._get_dispatch_points, regions)
            )
            # The imap_unordered chain is lazy, so consume it while the pool is open.
            return list(chain.from_iterable(pool.imap_unordered(
                lambda x: ((x, y) for y in self._get_arrival_points(x)),
                dispatches,
            )))

    @ttl_cache(maxsize=128)
    @retry(5, TransportError, 3)
    def _get_countries(self):
        """Return the partner's countries (``getCountries``); TTL-cached, retried on TransportError."""
        return self.call('getCountries', [])

    @ttl_cache(maxsize=256)
    @retry(5, TransportError, 3)
    def _get_regions(self, country_id):
        """Return the regions of ``country_id`` (``getRegions``); TTL-cached, retried on TransportError."""
        return self.call('getRegions', [country_id])

    @ttl_cache(maxsize=1024)
    @bottleneck(calls=RPS_LIMIT)  # rps
    @retry(5, TransportError, 3)
    def _get_dispatch_points(self, region):
        """Return dispatch points of ``region``, each tagged with ``region_info``.

        The partner objects returned by ``getDispatchPoints`` are mutated in
        place. Cached results skip the ``RPS_LIMIT`` bottleneck.
        """
        dispatches = self.call('getDispatchPoints', [region['id']])
        for dispatch in dispatches:
            dispatch['region_info'] = region
        return dispatches

    @ttl_cache(maxsize=8096)
    @bottleneck(calls=RPS_LIMIT)  # rps
    @retry(5, TransportError, 3)
    def _get_arrival_points(self, dispatch):
        """Return arrival points reachable from ``dispatch`` (``getArrivalPoints``)."""
        return self.call('getArrivalPoints', [dispatch['id']])

    @ttl_cache(maxsize=1024)
    @retry(5, TransportError, 3)
    def _get_depots(self):
        """Return all partner depots (``getDepots``); TTL-cached, retried on TransportError."""
        return self.call('getDepots', [])

    def _make_endpoints(self, endpoints, depots):
        """Yield unique endpoints enriched with depot data, then the unmatched depots.

        Endpoints are deduplicated by ``id``, preferring a copy that carries
        ``region_info`` (dispatch points have it; see ``_get_dispatch_points``).
        An endpoint with a matching depot gets its ``address`` and ``timezone``.
        Note: endpoints are mutated in place, and they may be objects held in
        the ``ttl_cache`` of the ``_get_*_points`` methods.
        """
        depots_dict = {dpt['id']: dpt for dpt in depots}

        endpoints_dict = {}
        for endpoint in endpoints:
            supplier_point_id = endpoint['id']
            stored_endpoint = endpoints_dict.get(supplier_point_id)
            if stored_endpoint is None or ('region_info' in endpoint and 'region_info' not in stored_endpoint):
                endpoints_dict[supplier_point_id] = endpoint

        for supplier_point_id, endpoint in endpoints_dict.items():
            # pop so that only depots without a matching endpoint remain below
            depot = depots_dict.pop(supplier_point_id, None)
            if depot is not None:
                endpoint['address'] = depot['address']
                endpoint['timezone'] = depot['timezone']
            yield endpoint

        for depot in depots_dict.values():
            yield depot
