# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import datetime
import io
import logging
from copy import deepcopy
from functools import partial
from itertools import chain, ifilter, product

import requests
from flask import send_file

from yabus import common
from yabus.common.entities import RawSegments
from yabus.common.exceptions import (
    InvalidIdentifier, InvalidRide, InvalidTicket, PartnerError, PointNotFound, RideHasNoSubstitution
)
from yabus.util.parallelism import pmap

from yabus.atlasbus.converter import point_converter
from yabus.atlasbus.entities import Book, Endpoint, Order, Refund, RefundInfo, Ride, RideDetails
from yabus.atlasbus.session import Session

logger = logging.getLogger(__name__)


class Client(common.Client, common.HttpClient):
    """Atlasbus connector client built on the shared ``common`` client classes.

    Partner station ids are exposed as connector-qualified strings
    ``'<connector>:<station id>'`` (see ``_full_station_id``). They are mapped
    to and from internal point uids through ``point_converter``.
    """

    session_cls = Session
    converter = point_converter

    # Seconds to wait when downloading a ticket blank from the partner-provided URL.
    # ``requests`` has no default timeout and would otherwise block indefinitely.
    ticket_blank_timeout = 60

    def endpoints(self):
        """Return partner settlements (type 'city') followed by stations (type 'station')."""
        settlements = self.get('rides/endpoints')
        stations = self.get('rides/stations')
        return list(chain(
            Endpoint.init(settlements, type='city'),
            Endpoint.init(stations, type='station'),
        ))

    def _raw_segments(self):
        """Return the partner's segments as a list of ``(from, to)`` id pairs."""
        return [(s['from'], s['to']) for s in self.get('rides/segments')]

    def raw_segments(self):
        """Return the partner's raw segments wrapped in a ``RawSegments`` entity."""
        return RawSegments.init({
            'segments': self._raw_segments(),
        })

    def segments(self):
        """Return the partner's segments mapped through the point converter."""
        raw_segments = self._raw_segments()
        return list(self.converter.gen_map_segments(raw_segments))

    @staticmethod
    def _full_station_id(connector, short_id):
        """Build a connector-qualified station id: ``'<connector>:<short_id>'``."""
        return '{}:{}'.format(connector, short_id)

    @staticmethod
    def _short_station_id(full_id):
        """Inverse of ``_full_station_id``: drop everything up to and including the first ':'.

        Raises ValueError if ``full_id`` contains no ':'.
        """
        return full_id[full_id.index(':') + 1:]

    def _map_to_settlements(self, uid):
        """Return the partner settlement ids to search for the internal point ``uid``.

        A uid starting with 's' (presumably a station uid) is mapped to its parent
        points. Any other uid is mapped directly. An unknown uid yields an empty list.
        """
        if uid.startswith('s'):
            return self.converter.map_to_parents(uid)
        try:
            return self.converter.map(uid)
        except PointNotFound:
            return []

    def _product_rides(self, original_ride, from_uid='', to_uid=''):
        """Expand one partner ride into a ``Ride`` per (pickup stop, discharge stop) pair.

        ``from_uid`` / ``to_uid`` narrow the stops that are considered:

        * a uid starting with 's' keeps only stops whose backmapped uids contain it;
        * a uid starting with 'c' keeps stops flagged ``important`` (all stops if none are);
        * an empty or any other uid keeps every stop.

        Stop ids in the produced rides are connector-qualified.
        """
        def get_stations(original_stations, uid):
            if uid.startswith('s'):
                stations = []
                for station in original_stations:
                    full_id = self._full_station_id(original_ride['connector'], station['id'])
                    try:
                        s_uids = self.converter.backmap(full_id)
                    except PointNotFound:
                        continue
                    if uid in s_uids:
                        stations.append(station)
                return stations
            if uid.startswith('c'):
                important_stations = [s for s in original_stations if s.get('important', False)]
                if important_stations:
                    return important_stations
            return original_stations

        def with_full_id(station):
            # Work on a deep copy so the partner payload and sibling rides are never mutated.
            station_copy = deepcopy(station)
            station_copy['id'] = self._full_station_id(original_ride['connector'], station['id'])
            return station_copy

        from_stations = get_stations(original_ride['pickupStops'], from_uid)
        to_stations = get_stations(original_ride['dischargeStops'], to_uid)

        rides = []
        for from_station, to_station in product(from_stations, to_stations):
            patch = {
                'from': with_full_id(from_station),
                'to': with_full_id(to_station),
                'departure': from_station['datetime'],
                'arrival': to_station['datetime'],
            }
            rides.append(Ride.init(patch, **original_ride))
        return rides

    def psearch(self, date, passengers_count, segment):
        """Search partner rides for a single ``(from_cid, to_cid)`` settlement pair.

        ``date`` may be a ``datetime.date``, a ``datetime.datetime``, or an
        already-formatted string, which is passed through unchanged. Each returned
        ride is annotated with ``date_str``, ``from_scid`` and ``to_scid``.
        ``ride_details`` and ``book`` read identifier keys with these names.

        Returns the partner response dict, or None (after logging) on PartnerError.
        """
        from_cid, to_cid = segment
        # datetime.datetime subclasses datetime.date, so this formats both. A plain
        # date would otherwise leak unformatted into ``date_str``.
        date_str = date.strftime('%Y-%m-%d') if isinstance(date, datetime.date) else date
        params = {
            'date': date_str,
            'from-id': from_cid,
            'to-id': to_cid,
            'passengers-count': passengers_count,
        }
        try:
            resp = self.get('rides/search', params=params)
        except PartnerError:
            logger.exception('cannot get rides from_cid=%s to_cid=%s date=%s passengers-count=%d',
                             from_cid, to_cid, date, passengers_count)
            return None
        for ride in resp['rides']:
            ride['date_str'] = date_str
            ride['from_scid'] = from_cid
            ride['to_scid'] = to_cid
        return resp

    def search(self, from_uid, to_uid, date, _=False):
        """Return rides between internal points ``from_uid`` and ``to_uid`` on ``date``.

        Every (from settlement, to settlement) pair is searched in parallel via
        ``pmap`` for one passenger. Pairs whose search failed (None) are skipped.
        """
        from_ids = self._map_to_settlements(from_uid)
        to_ids = self._map_to_settlements(to_uid)
        segments = product(from_ids, to_ids)
        search_results = ifilter(None, pmap(partial(self.psearch, date, 1), segments))

        rides = []
        for search_result in search_results:
            for original_ride in search_result['rides']:
                rides.extend(self._product_rides(original_ride, from_uid, to_uid))
        return rides

    def ride(self, ride_id):
        """Fetch a partner ride and return the expanded ``Ride`` whose ``@id`` matches ``ride_id``.

        Raises InvalidIdentifier if ``ride_id`` lacks ``ride_sid``.
        Raises InvalidRide if no expanded ride matches.
        """
        try:
            ride_sid = ride_id['ride_sid']
        except KeyError as e:
            raise InvalidIdentifier(e)
        ride_id_dump = ride_id.dumps()
        original_ride = self.get('rides/{}'.format(ride_sid), monitoring_path='rides/ID')
        for ride in self._product_rides(original_ride):
            if ride['@id'] == ride_id_dump:
                return ride
        raise InvalidRide('No rides with @id={} for sid={}'.format(ride_id_dump, ride_sid))

    def ride_details(self, ride_id):
        """Return the partner's booking parameters for a ride as ``RideDetails``.

        Raises InvalidIdentifier if ``ride_id`` lacks any required key.
        """
        try:
            ride_sid = ride_id['ride_sid']
            params = {
                'from-id': ride_id['from_scid'],
                'to-id': ride_id['to_scid'],
                'date': ride_id['date_str'],
            }
        except KeyError as e:
            raise InvalidIdentifier(e)

        book_params = self.get(
            'rides/{}/book-params'.format(ride_sid),
            params=params,
            monitoring_path='rides/ID/book-params',
        )
        return RideDetails.init(book_params)

    def book(self, ride_id, passengers, pay_offline):
        """Book ``passengers`` on the ride identified by ``ride_id`` and return the ``Order``.

        For more than one passenger, the segment is re-searched with the real
        passenger count. The partner ride with the same departure, arrival and
        stops is booked instead of ``ride_id['ride_sid']``.

        Raises InvalidIdentifier if ``ride_id`` lacks any required key.
        Raises RideHasNoSubstitution if the re-search fails or finds no matching ride.
        """
        try:
            ride_sid = ride_id['ride_sid']
            from_ssid = ride_id['from_ssid']
            to_ssid = ride_id['to_ssid']
            from_scid = ride_id['from_scid']
            to_scid = ride_id['to_scid']
            departure = ride_id['departure']
            arrival = ride_id['arrival']
            date_str = ride_id['date_str']
        except KeyError as e:
            raise InvalidIdentifier(e)

        passengers_count = len(passengers)
        if passengers_count > 1:
            search_result = self.psearch(date_str, passengers_count, (from_scid, to_scid))

            def get_changed_ride_sid():
                if search_result is None:
                    # psearch already logged the PartnerError. Report "no substitution"
                    # rather than crashing with a TypeError on None['rides'].
                    return None
                for original_ride in search_result['rides']:
                    for ride in self._product_rides(original_ride):
                        if (ride['departure'] == departure and ride['arrival'] == arrival
                                and ride['from']['supplier_id'] == from_ssid and ride['to']['supplier_id'] == to_ssid):
                            return original_ride['id']
                return None

            changed_ride_sid = get_changed_ride_sid()

            if changed_ride_sid is None:
                raise RideHasNoSubstitution(
                    context='can not found correspondent ride for rideId={}, passengers_count={}'.format(
                        ride_id, passengers_count
                    )
                )
            ride_sid = changed_ride_sid

        book_params = {
            'passengers': [Book.init(
                p,
                pickupStopId=self._short_station_id(from_ssid),
                dischargeStopId=self._short_station_id(to_ssid),
            ) for p in passengers],
        }
        # Start the booking without polling, then fetch its result by bookId.
        book_start_resp = self.post(
            'rides/{}/book'.format(ride_sid),
            json=book_params,
            polling=False,
            monitoring_path='rides/ID/book')
        book_id = book_start_resp['bookId']
        book_resp = self.get(
            'rides/{}/book'.format(ride_sid),
            params={'bookId': book_id},
            monitoring_path='rides/ID/book',
        )
        return Order.init(book_resp)

    def confirm(self, order_id):
        """Confirm payment of every ticket in an order and return the refreshed ``Order``.

        Raises InvalidIdentifier if ``order_id`` lacks ``order_sid``.
        """
        try:
            order_sid = order_id['order_sid']
        except KeyError as e:
            raise InvalidIdentifier(e)

        raw_order = self.get('orders/{}'.format(order_sid), monitoring_path='orders/ID')

        confirm_params = {
            'saleType': 'channel',
            'tickets': [{
                'paymentInfo': [{
                    'price': t['price'],
                    'paymentType': 'card'
                }],
                # NOTE(review): snake_case among camelCase keys; confirm against the partner API.
                'ticket_id': t['id']
            } for t in raw_order['tickets']],
        }
        self.post('orders/{}/confirm'.format(order_sid), json=confirm_params, monitoring_path='orders/ID/confirm')

        # Re-read the order so the returned entity reflects the confirmed state.
        raw_order = self.get('orders/{}'.format(order_sid), monitoring_path='orders/ID')
        return Order.init(raw_order, tickets_fwd={'__gen_url__': True})

    def order(self, order_id):
        """Fetch an order from the partner and return it as an ``Order``.

        Raises InvalidIdentifier if ``order_id`` lacks ``order_sid``.
        """
        try:
            order_sid = order_id['order_sid']
        except KeyError as e:
            raise InvalidIdentifier(e)
        raw_order = self.get('orders/{}'.format(order_sid), monitoring_path='orders/ID')
        # NOTE(review): confirm() forwards '__gen_url__' while this forwards 'gen_url';
        # verify which key Order's ticket entity actually expects.
        return Order.init(raw_order, tickets_fwd={'gen_url': True})

    def refund_info(self, ticket_id):
        """Return the partner's refund calculation for a ticket as ``RefundInfo``.

        Raises InvalidIdentifier if ``ticket_id`` lacks ``ticket_sid``.
        """
        try:
            ticket_sid = ticket_id['ticket_sid']
        except KeyError as e:
            raise InvalidIdentifier(e)
        raw_refund_info = self.get('tickets/{}/calc-refund'.format(ticket_sid),
                                   monitoring_path='tickets/ID/calc-refund')
        return RefundInfo.init(raw_refund_info)

    def refund(self, ticket_id):
        """Request a refund of a ticket and return the partner's response as ``Refund``.

        Raises InvalidIdentifier if ``ticket_id`` lacks ``ticket_sid``.
        """
        try:
            ticket_sid = ticket_id['ticket_sid']
        except KeyError as e:
            raise InvalidIdentifier(e)
        refund_params = {
            'saleType': 'channel',
        }
        raw_refund = self.post(
            'tickets/{}/refund'.format(ticket_sid),
            json=refund_params,
            monitoring_path='tickets/ID/refund',
        )
        return Refund.init(raw_refund)

    def ticket(self, ticket_id):
        raise NotImplementedError

    def ticket_blank(self, ticket_id):
        """Download a ticket's blank from its partner URL and return it as a file response.

        Raises InvalidIdentifier if ``ticket_id`` lacks ``ticket_sid`` or ``order_sid``.
        Raises InvalidTicket if the ticket is not in the order or has no URL.
        ``requests`` exceptions (HTTP errors, timeouts) propagate.
        """
        try:
            ticket_sid = ticket_id['ticket_sid']
            order_sid = ticket_id['order_sid']
        except KeyError as e:
            raise InvalidIdentifier(e)

        raw_order = self.get('orders/{}'.format(order_sid), monitoring_path='orders/ID')
        for ticket in raw_order['tickets']:
            if ticket['id'] != ticket_sid:
                continue

            if not ticket.get('url'):
                raise InvalidTicket('No url for ticket blank: ticket={} in order={}'.format(ticket_sid, order_sid))

            ticket_resp = requests.get(ticket['url'], timeout=self.ticket_blank_timeout)
            ticket_resp.raise_for_status()
            return send_file(
                io.BytesIO(ticket_resp.content),
                mimetype='application/octet-stream',
                as_attachment=True,
                attachment_filename='ticket_{}.pdf'.format(ticket_sid)
            )
        raise InvalidTicket('No ticket={} in order={}'.format(ticket_sid, order_sid))

    def cancel(self, ticket_id):
        raise NotImplementedError

    def change_ride_endpoints(self, ride_id, pickup_sid, discharge_sid):
        raise NotImplementedError
