# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from functools import partial
from itertools import chain, combinations, islice, product
from os.path import join

from cachetools.func import ttl_cache

from yabus import common
from yabus.common.entities import RawSegments
from yabus.common.exceptions import InvalidRide, OfflinePaymentDisabled, PointNotFound, PartnerError
from yabus.util import decode_utf8, deduplicate, monitoring
from yabus.util.parallelism import ThreadPool, pmap
from yabus.yugavtotrans.baseclient import BaseClient
from yabus.yugavtotrans.carrier_provider import carrier_provider
from yabus.yugavtotrans.citizenships import citizenships
from yabus.yugavtotrans.converter import point_converter
from yabus.yugavtotrans.entities import Book, Endpoint, Order, Refund, RefundInfo, Ride, RideDetails, Ticket
from yabus.yugavtotrans.exceptions import parse_error
from yabus.yugavtotrans.session import Session

logger = logging.getLogger(__name__)


class Client(common.Client, BaseClient):
    """Yugavtotrans partner client.

    HTTP helpers (``get``, ``post``, ``session``) are inherited; presumably they
    come from ``BaseClient`` (not visible here).
    """
    converter = point_converter
    session_cls = Session
    carrier_provider = carrier_provider

    def endpoints(self):
        """Return departure and arrival stations as ``Endpoint`` objects.

        Stations from ``station/list/from`` and ``station/list/to`` are merged
        and deduplicated by ``(station_id, city_id)``.
        """
        def fetch(direction):
            return self.get('station/list/' + direction)['stations']

        # Fetch both lists concurrently; pmap keeps the input order.
        stations = deduplicate(
            chain.from_iterable(pmap(fetch, ['from', 'to'])),
            key=lambda x: (x['station_id'], x['city_id']),
        )
        return [Endpoint.init(decode_utf8(x)) for x in stations]

    def segments(self):
        """Return the route segments mapped through ``self.converter``."""
        return list(self.converter.gen_map_segments(self._raw_segments()))

    def raw_segments(self):
        """Return the unmapped partner station-id pairs wrapped in ``RawSegments``."""
        return RawSegments.init({
            'segments': self._raw_segments(),
        })

    def psearch(self, direction, date):
        """Fetch raw rides for one partner station pair on ``date``.

        :param direction: a ``(from_sid, to_sid)`` pair of partner station ids.
        :param date: any object with ``strftime`` (sent as ``dd.mm.YYYY``).
        :returns: the ``ride/list`` response, or ``[]`` on ``PartnerError``
            (logged), so one failing pair does not break a multi-pair search.
        """
        from_sid, to_sid = direction
        params = {
            'from': from_sid,
            'to': to_sid,
            'date': date.strftime('%d.%m.%Y'),
        }
        try:
            return self.get('ride/list', params=params)
        except PartnerError:
            logger.exception('cannot get rides from_sid=%s to_sid=%s date=%s', from_sid, to_sid, date)
            return []

    def search(self, from_uid, to_uid, date, _=False):
        """Search rides between two points on ``date``.

        Every partner station mapped from ``from_uid`` is combined with every
        station mapped from ``to_uid``; the pairs are queried concurrently.
        Returns ``[]`` when ``PointNotFound`` is raised. The last argument is
        unused here.
        """
        try:
            func = partial(self.psearch, date=date)
            sids = [self.converter.deprecated_map(uid) for uid in (from_uid, to_uid)]
            rides = list(chain.from_iterable(pmap(func, product(*sids))))
            return Ride.init(rides, carrier_provider=self.carrier_provider)
        except PointNotFound:
            return []

    @parse_error
    def ride_details(self, ride_id):
        """Return ``RideDetails`` for the ride identified by ``ride_id``.

        ``ride_id`` must contain ``trip_sid``, ``from_sid``, ``to_sid`` and
        ``date``. Raises ``InvalidRide`` if the search for that station pair and
        date has no ride whose ``trip_id`` equals ``trip_sid``.
        """
        trip_sid, from_sid, to_sid, date = (ride_id[k] for k in ['trip_sid', 'from_sid', 'to_sid', 'date'])
        rides = self.psearch((from_sid, to_sid), date)
        ride = next((r for r in rides if r['trip_id'] == trip_sid), None)
        if ride is None:
            raise InvalidRide
        return RideDetails.init(ride)

    @parse_error
    def book(self, ride_id, passengers, pay_offline):
        """Reserve seats on a ride and return the resulting ``Order``.

        Raises ``OfflinePaymentDisabled`` if ``pay_offline`` is truthy. The
        citizenship and document dictionaries are fetched concurrently with the
        reservation request. They are checked afterwards by ``_validate``, which
        only logs and reports gauges, so the check never blocks a booking.
        """
        if pay_offline:
            raise OfflinePaymentDisabled
        reservation = {
            'trip_id': ride_id['trip_sid'],
            'variant_id': ride_id['variant_sid'],
            'ticket': True,
            # Passengers are keyed by their seat; 'seat' is removed from each payload.
            'passenger': {p.pop('seat'): p for p in map(Book.init, passengers)},
        }
        methods = [
            self._get_countries,
            self._get_documents,
            partial(self.post, 'ticket/reservation', json=reservation),
        ]
        countries, documents, booking = pmap(lambda method: method(), methods)
        self._validate(countries, documents)
        return Order.init(booking)

    @parse_error
    def confirm(self, order_id):
        """Pay for the tickets of an order and return the resulting ``Order``.

        ``order_id['tickets_id']`` holds ticket ids joined with ``-``. Tickets
        whose current status is ``payed`` are not sent to ``tickets/pay`` again;
        their raw ticket data is appended to the response's ``tickets`` instead.
        """
        tickets_id = order_id['tickets_id'].split('-')

        def fetcher(tid):
            return tid, self.raw_ticket(tid)

        sold_tickets = {tid: tkt for tid, tkt in pmap(fetcher, tickets_id) if tkt['status'] == 'payed'}
        unpaid_ids = [tid for tid in tickets_id if tid not in sold_tickets]
        with self.session() as session:
            if unpaid_ids:
                response = session.post('tickets/pay', json={
                    'key': session.key,
                    'ticket': True,
                    'tickets': {str(i): tid for i, tid in enumerate(unpaid_ids)},
                })
            else:
                response = {'tickets': []}
            response['tickets'] += list(sold_tickets.values())
            return Order.init(response)

    @parse_error
    def refund(self, ticket_id):
        """Return the ticket ``ticket_id['ticket_sid']`` and wrap the reply in ``Refund``."""
        ticket_sid = ticket_id['ticket_sid']
        with self.session() as session:
            # URL path, not a filesystem path: build it explicitly with '/'.
            url = 'ticket/{}/return'.format(ticket_sid)
            return Refund.init(session.post(url, json={
                'key': session.key,
                'ticket': True,
            }))

    def ticket(self, ticket_sid):
        """Return the partner ticket ``ticket_sid`` wrapped in ``Ticket``."""
        return Ticket.init(self.raw_ticket(ticket_sid))

    @parse_error
    def raw_ticket(self, ticket_sid):
        """Return the raw ``ticket/<sid>/`` response for ``ticket_sid``."""
        return self.get('ticket/{}/'.format(ticket_sid))

    def refund_info(self, ticket_id):
        """Return an empty ``RefundInfo``; ``ticket_id`` is not used here."""
        return RefundInfo.init({})

    def _raw_segments(self):
        """Return every ordered ``(from_station_id, to_station_id)`` pair of each route.

        For each route from ``routes/list``, each station in ``point_name`` is
        paired with every station listed after it.
        """
        routes = self.get('routes/list')['routes']
        return [
            (x['station_id'], y['station_id'])
            for route in routes
            for x, y in combinations(route['point_name'], 2)
        ]

    # NOTE(review): ttl_cache on an instance method keys on ``self``; with
    # maxsize=1 different Client instances evict each other and the cache keeps
    # the last instance alive. Fine for a single client; confirm that is the case.
    @ttl_cache(maxsize=1, ttl=24 * 60 * 60)  # ttl = 24 hours
    def _get_countries(self):
        """Return the partner citizenship list (``[]`` if the key is absent)."""
        return self.get('citizenship/list').get('citizenships', [])

    @ttl_cache(maxsize=1, ttl=24 * 60 * 60)  # ttl = 24 hours
    def _get_documents(self):
        """Return the partner document-type list (``[]`` if the key is absent)."""
        return self.get('document/list').get('documents', [])

    @staticmethod
    def _validate(countries, documents):
        """Report local dictionary keys that the partner no longer lists.

        Keys in ``citizenships`` or ``RideDetails.DocumentType.conv`` that are
        absent from the partner lists are logged and reported as the
        ``yugavtotrans.client.outdated_keys`` gauge. Empty partner lists are
        skipped. Never raises on a mismatch.
        """
        # NOTE(review): the gauge is only set when a diff exists, so it is never
        # reset to 0 once the dictionaries are fixed; confirm that is intended.
        if countries:
            countries_diff = set(citizenships) - set(country['citizenship_id'] for country in countries)
            if countries_diff:
                logger.error("citizenship dictionary has outdated keys: %s", countries_diff)
                monitoring.set_gauge(
                    'yugavtotrans.client.outdated_keys', len(countries_diff), {'dictionary': 'citizenship'}
                )
        if documents:
            documents_diff = set(RideDetails.DocumentType.conv) - set(doc['doc_id'] for doc in documents)
            if documents_diff:
                logger.error("documents dictionary has outdated keys: %s", documents_diff)
                monitoring.set_gauge(
                    'yugavtotrans.client.outdated_keys', len(documents_diff), {'dictionary': 'documents'}
                )
