# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import gevent
import json
import logging
from flask import Response
from itertools import chain, product

from yabus.util import decode_utf8, deduplicate
from yabus.util.bottleneck import bottleneck
from yabus.util.parallelism import ThreadPool, pmap

from yabus import common
from yabus.common.entities import RawSegments
from yabus.common.exceptions import OfflinePaymentDisabled, PointNotFound, PartnerError
from yabus.unitiki.baseclient import BaseClient
from yabus.unitiki.converter import point_converter
from yabus.unitiki.defaults import SUPPLIER_CODE
from yabus.unitiki.entities import Book, Endpoint, Order, Refund, RefundInfo, Ride, RideDetails, Ticket
from yabus.unitiki.exceptions import parse_error


# Module-level logger used for Unitiki partner errors and realtime-search problems.
logger = logging.getLogger(name=__name__)


class Client(common.Client, BaseClient):
    """Unitiki supplier client.

    Implements the common client operations (endpoints, segments, search,
    ride details, booking, orders, refunds, tickets) on top of the Unitiki
    HTTP API. Unitiki point ids ("sids") are the partner's id with a
    one-letter prefix: ``'c'`` for cities, ``'s'`` for stations.

    NOTE(review): ``get``/``post``/``session``/``countries`` are not defined
    here; presumably they are provided by ``BaseClient`` -- confirm.
    """

    converter = point_converter
    # Upper bound, in seconds, on polling a realtime search before it is cancelled.
    SEARCH_TIMEOUT = 40
    # Pause, in seconds, between two polls of a realtime search.
    SEARCH_POLLING_INTERVAL = 1

    def endpoints(self):
        """Return all Unitiki cities followed by all stations as ``Endpoint`` entities."""
        with ThreadPool(size=4) as pool:
            # Materialize while the pool context is still open, so nothing that
            # may rely on the pool is evaluated after the context has exited.
            raw_endpoints = list(chain(
                self._endpoints('city', pool),
                self._endpoints('station', pool),
            ))
        return [Endpoint.init(decode_utf8(x)) for x in raw_endpoints]

    def segments(self):
        """Return the unique segments the point converter derives from the raw city pairs."""
        return list(set(point_converter.gen_map_segments(self._raw_segments())))

    def raw_segments(self):
        """Return the raw ``(departure_sid, arrival_sid)`` city pairs as ``RawSegments``."""
        return RawSegments.init({
            'settlements_only': True,
            'segments': self._raw_segments(),
        })

    def psearch(self, from_sid, to_sid, date):
        """Search rides through Unitiki's cached ``ride/list`` endpoint.

        A sid starting with ``'s'`` is sent as a station id, anything else as
        a city id; the one-letter prefix is stripped in both cases.

        Returns ``Ride.init`` of the ride list, or ``[]`` on any error (the
        error is logged): this search is deliberately best-effort.
        """
        start_key = 'station_id_start' if from_sid.startswith('s') else 'city_id_start'
        end_key = 'station_id_end' if to_sid.startswith('s') else 'city_id_end'
        payload = {
            start_key: from_sid[1:],
            end_key: to_sid[1:],
            'date': date.strftime('%Y-%m-%d'),
        }
        try:
            response = self.get('ride/list', payload)
            return Ride.init(response.get('ride_list', []))
        except Exception:
            logger.exception('Unitiki error')
            return []

    def realtime_search(self, from_sid, to_sid, date):
        """Run a Unitiki realtime (uncached) search between two cities.

        Creates a search request, then polls ``ride/search/result`` every
        ``SEARCH_POLLING_INTERVAL`` seconds until the partner marks it ready.
        If ``SEARCH_TIMEOUT`` seconds pass first, the search is cancelled
        (best-effort). The result is either the cancel response's ride list,
        when non-empty, or the rides collected so far.

        Returns ``[]`` when either sid is a station, when request creation
        fails with ``PartnerError``, or when the create response lacks
        ``ride_search``. A malformed poll response stops polling and returns
        the rides collected so far instead of raising.
        """
        # only cities can be in from_sid and to_sid
        if from_sid[0] == 's' or to_sid[0] == 's':
            return []
        search_payload = {
            'city_id_start': from_sid[1:],
            'city_id_end': to_sid[1:],
            'date': date.strftime('%Y-%m-%d'),
        }
        try:
            create_res = self.post('ride/search/request/create', search_payload)
        except PartnerError as exc:
            # error_code 2 is a known code for "Invalid city_id_start" or
            # "Invalid city_id_end"; it is expected and not worth logging.
            if exc.fault.get('error_code') != 2:
                logger.info('unitiki realtime search error: %s', exc)
            return []
        ride_search = create_res.get('ride_search')
        if not ride_search:
            logger.error('unexpected response format in unitiki realtime search, absent "ride_search"')
            return []
        search_id = ride_search.get('search_id')
        ready = ride_search.get('status')
        rides = create_res.get('ride_list', [])
        if not ready and search_id is None:
            # Nothing to poll or cancel without an id; polling with None would
            # only spin until the timeout.
            logger.error('unexpected response format in unitiki realtime search, absent "search_id"')
            return Ride.init(rides)
        with gevent.Timeout(self.SEARCH_TIMEOUT) as timeout:
            try:
                while not ready:
                    gevent.sleep(self.SEARCH_POLLING_INTERVAL)
                    search_result = self.get('ride/search/result', {'search_id': search_id})
                    poll_state = search_result.get('ride_search')
                    if not poll_state:
                        # Used to raise TypeError here and abort the caller's
                        # whole (pmap-parallel) search; keep what we have instead.
                        logger.error('unexpected response format in unitiki realtime search result, '
                                     'absent "ride_search", search_id=%s', search_id)
                        break
                    ready = poll_state.get('status')
                    rides = search_result.get('ride_list', []) or rides
            except gevent.Timeout as t:
                # Only handle our own timeout; an outer one must propagate.
                if t is not timeout:
                    raise
                logger.error('timeouted unitiki realtime search, from_sid=%s to_sid=%s date=%s', from_sid, to_sid, date)
                rides = self._cancel_realtime_search(search_id, rides)
        return Ride.init(rides)

    def _cancel_realtime_search(self, search_id, rides):
        """Cancel a realtime search (best-effort) and return the rides to keep.

        Returns the cancel response's ``ride_list`` when it is non-empty,
        otherwise ``rides``. A failing cancel request is logged and does not
        discard the rides already collected.
        """
        try:
            cancel_result = self.post('ride/search/request/cancel', {'search_id': search_id})
        except Exception:
            logger.exception('failed to cancel unitiki realtime search, search_id=%s', search_id)
            return rides
        return cancel_result.get('ride_list') or rides

    def search(self, from_uid, to_uid, date, try_no_cache=False):
        """Search rides between two points across all of their Unitiki sids.

        Each uid is mapped to Unitiki sids with ``point_converter.map``. Every
        same-type pair (city-city or station-station) is searched in parallel
        with ``pmap``. Each result is filtered with
        ``common.respfilters.coherence``, and the results are concatenated.

        :param try_no_cache: when true, realtime search may be used (see
            ``_get_search_func``).
        :returns: list of rides, or ``[]`` if a point cannot be mapped.
        """
        try:
            sids = map(point_converter.map, [from_uid, to_uid])
            #  only station-station or city-city searches allowed, city-city in realtime
            directions = ((x, y) for x, y in product(*sids) if x[0] == y[0])
            # The search function depends only on the uids, so choose it once
            # rather than once per direction (it may query parent relations).
            search_func = self._get_search_func(from_uid, to_uid, try_no_cache)

            def func(route):
                from_sid, to_sid = route
                rides = search_func(from_sid, to_sid, date)
                return common.respfilters.coherence(SUPPLIER_CODE, from_uid, to_uid, from_sid, to_sid, date, rides)

            retval = pmap(func, directions)
            return list(chain.from_iterable(retval))
        except PointNotFound:
            return []

    @parse_error('ride_details')
    def ride_details(self, ride_id):
        """Return ``RideDetails`` for a ride segment.

        Requests the ride itself with ``from_cache='0'``, then fetches free
        positions, accepted identity cards and the bus place scheme in
        parallel. All response dicts are merged into one; on key clashes,
        later responses win.
        """
        payload = {
            'ride_segment_id': ride_id['ride_sid'],
        }
        ride_response = self.get('ride', dict(payload, from_cache='0'))
        extra_responses = pmap(lambda path: self.get(path, payload), [
            'ride/position/free',
            'ride/card_identity/list',
            'ride/bus/scheme/place',
        ])
        responses = [ride_response] + extra_responses
        return RideDetails.init(
            {k: v for x in responses for k, v in x.items()},
            citizenships=self.countries,
        )

    @parse_error('book')
    def book(self, ride_id, passengers, pay_offline):
        """Create a temporary booking (``operation/booking/tmp``) and return it as ``Order``.

        :raises OfflinePaymentDisabled: if ``pay_offline`` is requested.
        """
        if pay_offline:
            raise OfflinePaymentDisabled
        payload = {
            'ride_segment_id': ride_id['ride_sid'],
            'ticket_data': json.dumps([Book.init(x) for x in passengers]),
        }
        response = self.post('operation/booking/tmp', payload)
        return Order.init(response['operation'], tickets_fwd={
            '__countries__': self.countries,
        })

    @parse_error('confirm')
    def confirm(self, order_id):
        """Buy a booked operation (``operation/buy``) and return it as ``Order``."""
        payload = {
            'operation_id': order_id['order_sid'],
        }
        with self.session() as session:
            response = session.post('operation/buy', payload)
            return Order.init(response['operation'], tickets_fwd={
                '__countries__': self.countries,
                '__gen_url__': True,
            })

    @parse_error('order')
    def order(self, order_id):
        """Fetch an operation and return it as ``Order``."""
        payload = {
            'operation_id': order_id['order_sid'],
        }
        with self.session() as session:
            response = session.get('operation/', payload)
            return Order.init(response['operation'], tickets_fwd={
                '__countries__': self.countries,
                '__gen_url__': True,
            })

    @parse_error('refund')
    def refund_info(self, ticket_id):
        """Return the refund calculation for a ticket as ``RefundInfo``.

        A missing ``ticket_refund`` in the response is treated as ``{}``.
        """
        payload = {
            'operation_id': ticket_id['order_sid'],
            'ticket_id': ticket_id['ticket_sid'],
        }
        response = self.get('ticket/refund/calc', payload)
        return RefundInfo.init(response.get('ticket_refund', {}))

    @parse_error('refund')
    def refund(self, ticket_id):
        """Refund a ticket (``ticket/refund``) and return the result as ``Refund``."""
        payload = {
            'operation_id': ticket_id['order_sid'],
            'ticket_id': ticket_id['ticket_sid'],
        }
        response = self.post('ticket/refund', payload)
        return Refund.init(response['ticket_refund'])

    @parse_error('ticket')
    def ticket(self, ticket_id):
        """Fetch a single ticket and return it as ``Ticket``."""
        payload = {
            'ticket_id': ticket_id['ticket_sid'],
        }
        response = self.get('ticket', payload)
        return Ticket.init(response['ticket'], countries=self.countries, gen_url=True)

    def ticket_blank(self, ticket_id):
        """Return the operation's PDF blank as a Flask ``Response``.

        The operation is fetched first to obtain its ``hash``. The hash is
        passed to ``operation/pdf``, together with the operation id.
        """
        operation = self.get('operation', {
            'operation_id': ticket_id['order_sid'],
        })
        blank_content = self.get('operation/pdf', {
            'operation_id': ticket_id['order_sid'],
            'operation_hash': operation['operation']['hash'],
        }, raw_content=True)
        return Response(blank_content, mimetype='application/pdf')

    def _raw_segments(self):
        """Return ``(departure_sid, arrival_sid)`` pairs for every Unitiki city route.

        Lists departure cities (``city/list/from``) once. For each of them,
        the arrival cities (``city/list/to``) are fetched concurrently through
        a ``ThreadPool(size=10)`` that shares one session. Only cities are
        considered.
        """
        type_ = 'city'
        prefix = type_[0]
        id_key = '{}_id'.format(type_)
        with ThreadPool(size=10) as pool, self.session() as session:
            dispatches = session.get('{}/list/from'.format(type_), params={})['{}_list'.format(type_)]

            def arrival_pairs(dispatch):
                # Build the full list inside the worker so all per-departure
                # work (request and pair construction) happens there.
                departure_id = dispatch[id_key]
                return [
                    (prefix + departure_id, prefix + arrival[id_key])
                    for arrival in self._get_arrival_endpoints(type_, departure_id, session)
                ]

            return list(chain.from_iterable(pool.map(arrival_pairs, dispatches)))

    @bottleneck(calls=50)  # rps
    def _get_arrival_endpoints(self, type_, departure_endpoint, session):
        """Return the ``<type_>_list`` reachable from ``departure_endpoint`` (``<type_>/list/to``)."""
        return session.get('{}/list/to'.format(type_), params={
            '{}_id_start'.format(type_): departure_endpoint
        })['{}_list'.format(type_)]

    def _endpoints(self, type_, pool):
        """Yield endpoint dicts of one type (``'city'`` or ``'station'``).

        Merges the ``from`` and ``to`` lists, de-duplicates them on
        ``<type_>_id`` via ``deduplicate``, and adds ``id`` (prefixed sid),
        ``title``, ``country_code`` and ``type`` keys to each item.
        ``pool`` is accepted but currently unused.
        """
        endpoints = deduplicate(chain.from_iterable(
            self.get('{}/list/{}'.format(type_, x), {})['{}_list'.format(type_)]
            for x in ['from', 'to']
        ), key=lambda x: x['{}_id'.format(type_)])
        prefix = type_[0]
        return (
            dict(
                endpoint,
                id=prefix + endpoint['{}_id'.format(type_)],
                title=endpoint['{}_title'.format(type_)],
                # Prefer the ISO 3166 field; fall back to the plain ISO field.
                country_code=endpoint.get('country_iso3166') or endpoint.get('country_iso'),
                type=type_,
            )
            for endpoint in endpoints
        )

    def _get_search_func(self, from_sid, to_sid, try_no_cache):
        """Choose between cached (``psearch``) and realtime (``realtime_search``) search.

        Cached search is used unless ``try_no_cache`` is set. Realtime search
        works only with cities. So when both points are stations and either
        has no parents in ``point_converter.relations_provider``, fall back to
        cached search.

        NOTE(review): ``search`` passes its uids here, despite the ``*_sid``
        parameter names.
        """
        if not try_no_cache:
            return self.psearch
        if from_sid[0] == to_sid[0] == 's':
            from_parents = list(point_converter.relations_provider.get_parents(from_sid))
            to_parents = list(point_converter.relations_provider.get_parents(to_sid))
            if not from_parents or not to_parents:
                return self.psearch
        return self.realtime_search
