# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import six
from functools import partial
from itertools import chain, permutations, product
from os.path import join

from yabus import common
from yabus.common.entities import Refund, RawSegments
from yabus.common.exceptions import OfflinePaymentDisabled, PartnerError, PointNotFound
from yabus.ecolines import defaults
from yabus.ecolines.baseclient import BaseClient
from yabus.ecolines.converter import point_converter
from yabus.ecolines.entities import Book, Endpoint, Order, RefundInfo, Ride, RideDetails, Ticket
from yabus.ecolines.exceptions import parse_error
from yabus.util import decode_utf8
from yabus.util.parallelism import pmap

# Per-module logger, named after this module.
logger = logging.getLogger(name=__name__)

# Message of the ValueError raised when a journey has more than one leg.
_MULTILEG_JOURNEY_ERROR = 'multileg journeys not supported'


class Client(common.Client, BaseClient):
    """Ecolines supplier client.

    Implements the ``yabus`` client operations (stops, segments, search,
    ride details, booking, confirmation, refunds, tickets) on top of the
    inherited ``get``/``post`` request helpers and wraps supplier responses
    in ``yabus.ecolines`` entities.

    NOTE(review): request paths are built with ``os.path.join``; that gives
    ``/``-separated paths only on POSIX hosts and requires every path part
    to be a string.
    """

    # Maps yabus point ids to supplier stop ids (used by search/segments).
    converter = point_converter

    def endpoints(self):
        """Return every supplier stop as an ``Endpoint`` entity.

        Requests ``stops`` with ``all=1`` and UTF-8-decodes each record
        before wrapping it.
        """
        return [
            Endpoint.init(decode_utf8(x))
            for x in self.get('stops', params={'all': 1})
        ]

    def segments(self):
        """Return the supplier segments mapped through ``converter``."""
        return list(self.converter.gen_map_segments(self._raw_segments()))

    def raw_segments(self):
        """Return the unmapped (origin, destination) pairs as ``RawSegments``."""
        return RawSegments.init({
            'segments': self._raw_segments(),
        })

    def psearch(self, direction, date):
        """Search single-leg journeys between one pair of supplier stops.

        :param direction: ``(from_sid, to_sid)`` pair of supplier stop ids.
        :param date: departure date, sent as ``date.isoformat()``.
        :returns: ``Ride`` entities for the journeys without changes, or
            ``[]`` if the request raised ``PartnerError`` (the error is logged).
        """
        # Unpack in the body rather than in the signature: tuple parameters
        # are a SyntaxError on Python 3 (PEP 3113). Callers pass the pair
        # positionally, so the call interface is unchanged.
        from_sid, to_sid = direction
        params = {
            'outboundOrigin': from_sid,
            'outboundDestination': to_sid,
            'outboundDate': date.isoformat(),
            'currency': defaults.RUB,
            'applyDiscounts': 0,
            # Prices are requested for exactly one adult passenger.
            'adults': 1,
            'children': 0,
            'teens': 0,
            'seniors': 0,
        }

        try:
            journeys = self.get('journeys', params=params)
            return Ride.init(self._remove_multilegs_rides(journeys), stations=self._stations)
        except PartnerError:
            logger.exception('cannot get rides from_sid=%s to_sid=%s date=%s', from_sid, to_sid, date)
            return []

    def search(self, from_uid, to_uid, date, _=False):
        """Search rides between two yabus points on ``date``.

        Each point id is mapped with ``converter.deprecated_map`` to its
        supplier stop ids. Every (from, to) combination is searched with
        ``psearch`` in parallel (``pmap``), and the results are concatenated.
        The trailing ``_`` argument is ignored.

        :returns: list of rides, or ``[]`` if a point cannot be mapped
            (``PointNotFound``).
        """
        func = partial(self.psearch, date=date)
        try:
            # A list (not ``map``) keeps the mapping eager on Python 2 and 3 alike.
            sids = [self.converter.deprecated_map(uid) for uid in (from_uid, to_uid)]
            directions = product(*sids)
            return list(chain.from_iterable(pmap(func, directions)))
        except PointNotFound:
            return []

    @parse_error
    def ride_details(self, ride_id):
        """Return the leg, seats and fares of a ride as ``RideDetails``.

        The journey's leg and fares are fetched in parallel. The seats of
        that leg are then fetched and used for both ``map`` and ``seats``.

        :param ride_id: mapping with the supplier journey id under ``ride_sid``.
        """
        # ``_leg`` raises ValueError for multi-leg journeys; how ``parse_error``
        # handles it is defined elsewhere.
        funcs = (self._leg, self._fares)
        leg, fares = pmap(lambda fetch: fetch(ride_id['ride_sid']), funcs)
        seats = self._seats(leg['id'])
        return RideDetails.init(
            {
                'leg': leg,
                'map': seats,
                'seats': seats,
                'ticketTypes': fares,
            },
            stations=self._stations,
        )

    @parse_error
    def book(self, ride_id, passengers, pay_offline):
        """Create a booking for ``passengers`` on the ride's journey.

        :raises OfflinePaymentDisabled: if ``pay_offline`` is truthy.
        :returns: ``Order`` with status ``waiting`` and the booking's ticket
            headers.
        """
        if pay_offline:
            raise OfflinePaymentDisabled
        payload = {
            'id': None,
            'journey': ride_id['ride_sid'],
            'currency': defaults.RUB,
            # Must be a real list: a Python 3 ``map`` object is not
            # serializable by the stdlib ``json`` encoder.
            'passengers': [Book.init(passenger) for passenger in passengers],
        }
        order = self.post('bookings', json=payload)
        order['status'] = 'waiting'
        tickets = self._tickets_headers(order['id'])
        return Order.init(order, tickets=tickets)

    @parse_error
    def confirm(self, order_id):
        """Confirm booking ``order_id['order_sid']``.

        :returns: ``Order`` with status ``confirmed`` and its ticket headers.
        """
        order_sid = order_id['order_sid']
        path = join('bookings', order_sid, 'confirmation')
        self.post(path)
        order = {'id': order_sid, 'status': 'confirmed'}
        tickets = self._tickets_headers(order_sid)
        return Order.init(order, tickets=tickets)

    @parse_error
    def refund_info(self, ticket_id):
        """Return the first cancellation option of a ticket as ``RefundInfo``."""
        ticket_sid, order_sid = ticket_id['ticket_sid'], ticket_id['order_sid']
        path = join('bookings', order_sid, 'tickets', ticket_sid, 'cancellations')
        return RefundInfo.init(self.get(path)[0])

    def refund(self, ticket_id):
        """Confirm a ticket's cancellation.

        The ``Refund`` is built from ``refund_info`` before the cancellation
        is confirmed, and is returned afterwards.
        """
        refund = Refund.init(self.refund_info(ticket_id))
        ticket_sid, order_sid = ticket_id['ticket_sid'], ticket_id['order_sid']
        # NOTE(review): cancellation id '1' is hard-coded; presumably it is the
        # option read by ``refund_info`` (index 0) — confirm against the API.
        path = join('bookings', order_sid, 'tickets', ticket_sid, 'cancellations', '1', 'confirmation')
        self.post(path)
        return refund

    def ticket(self, ticket_id):
        """Return the ``Ticket`` entity for ``ticket_sid`` within ``order_sid``."""
        ticket_sid, order_sid = ticket_id['ticket_sid'], ticket_id['order_sid']
        ticket_header = self._ticket_header(order_sid, ticket_sid)
        return Ticket.init(ticket_header, order_id=order_sid)

    def _raw_segments(self):
        """Return supplier segments as text (origin, destination) pairs, both ways."""
        pairs = (
            (six.text_type(segment['origin']), six.text_type(segment['destination']))
            for segment in self.get('segments')
        )
        # permutations() of a 2-tuple yields (a, b) and (b, a), so every
        # segment is emitted in both directions.
        return list(chain.from_iterable(permutations(pair) for pair in pairs))

    def _fares(self, journey_id):
        """Return the fares of ``journey_id`` in RUB."""
        return self.get('fares', params={'journey': journey_id, 'currency': defaults.RUB})

    def _seats(self, leg_id):
        """Return the seats of leg ``leg_id``."""
        return self.get('seats', params={'leg': leg_id})

    def _leg(self, journey_id):
        """Return the single leg of ``journey_id``.

        :raises ValueError: if the journey has more than one leg.
        """
        legs = self.get('legs', params={'journey': journey_id})
        if len(legs) > 1:
            raise ValueError(_MULTILEG_JOURNEY_ERROR)
        # NOTE(review): an empty response raises IndexError here.
        return legs[0]

    def _ticket_header(self, order_sid, ticket_sid):
        """Return the header of ticket ``ticket_sid`` in order ``order_sid``.

        Raises ``StopIteration`` if the order has no ticket with that id.
        """
        return next(th for th in self._tickets_headers(order_sid) if th['id'] == ticket_sid)

    def _tickets_headers(self, order_sid):
        """Return the ticket headers of order ``order_sid``."""
        return self.get(join('bookings', order_sid, 'tickets'))

    @staticmethod
    def _remove_multilegs_rides(journeys):
        """Keep only journeys whose outbound part has zero changes."""
        return [x for x in journeys if x['outbound']['changes'] == 0]
