# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import io
import logging
import typing
from datetime import datetime
from itertools import chain, product
from zipfile import ZipFile

from dateutil import relativedelta
from dateutil.parser import parse as date_parse
from flask import send_file

from yabus import common

from yabus.common.exceptions import PointNotFound, InvalidRide, PartnerError, BookingError, RefundError, InvalidTicket, \
    InvalidTicketBlank
from yabus.util.parallelism import ThreadPool, pmap
from yabus.common.entities.raw_segments import RawSegments

from yabus.busfor import defaults
from yabus.busfor.baseclient import BaseClient
from yabus.busfor.converter import point_converter
from yabus.busfor.defaults import SUPPLIER_CODE

from yabus.busfor.entities import Book, Endpoint, Order, Ride, RideDetails, Refund, RefundInfo

logger = logging.getLogger(__name__)


class Client(common.Client, BaseClient):
    """Client for the Busfor partner API.

    Implements search, booking, confirmation, refund and ticket download on top of the HTTP helpers
    (``get`` / ``post`` / ``put``) that are inherited from the base classes.
    """
    converter = point_converter

    def endpoints(self):
        """Return every location and point endpoint known to the partner as a tuple."""
        # Both fetchers are generator functions: calling one merely creates a lazy generator. Drain it inside
        # the worker, otherwise all HTTP requests would run one after another in the calling thread.
        fetchers = [self._locations_endpoints, self._points_endpoints]
        return tuple(chain.from_iterable(pmap(lambda fetch: list(fetch()), fetchers)))

    def segments(self):
        """Return the partner's location pairs mapped through the point converter."""
        return list(self.converter.gen_map_segments(self._raw_segments()))

    def raw_segments(self):
        """Return the partner's location pairs wrapped into a ``RawSegments`` entity."""
        return RawSegments.init({
            'settlements_only': True,
            'segments': self._raw_segments(),
        })

    def psearch(self, from_sid, to_sid, date):
        """Search direct rides between two supplier ids on the given date.

        :param from_sid: prefixed supplier id of the departure; the first character (the prefix) is stripped
        :param to_sid: prefixed supplier id of the arrival; the first character (the prefix) is stripped
        :param date: object with ``strftime`` (date or datetime) of the departure
        :return: list of ``Ride`` entities; an empty list when the partner request fails
        """
        date_str = date.strftime('%Y-%m-%d')
        params = {
            'search_mode': 'direct',
            'date': date_str,
            'from_id': from_sid[1:],
            'to_id': to_sid[1:],
        }
        try:
            pages = self._fetch_pages('search/v2/trips', defaults.SEARCH_LIMIT, params)
            # Pages are fetched lazily, so the generator must be consumed inside the ``try``.
            return list(self._compose_rides(pages, date_str))
        except PartnerError:
            logger.exception('cannot get rides from_sid=%s to_sid=%s date=%s', from_sid, to_sid, date)
            return []

    def search(self, from_uid, to_uid, date, _=False):
        """Search rides between two unified point ids.

        Every combination of mapped departure and arrival *locations* is searched in parallel and the
        results are passed through the coherence filter.

        :return: flat list of rides; an empty list when either point cannot be mapped
        """

        def search_direction(direction):
            from_sid, to_sid = direction
            rides = self.psearch(from_sid, to_sid, date)
            return common.respfilters.coherence(SUPPLIER_CODE, from_uid, to_uid, from_sid, to_sid, date, rides)

        try:
            from_ids = self.converter.map(from_uid) | self.converter.map_to_parents(from_uid)
            to_ids = self.converter.map(to_uid) | self.converter.map_to_parents(to_uid)
        except PointNotFound:
            return []

        # Only ids carrying the location prefix are searched.
        from_locations = [sid for sid in from_ids if sid[0] == Endpoint.LOCATION_PREFIX]
        to_locations = [sid for sid in to_ids if sid[0] == Endpoint.LOCATION_PREFIX]

        return list(chain.from_iterable(pmap(search_direction, product(from_locations, to_locations))))

    def ride(self, ride):
        raise NotImplementedError

    def ride_details(self, ride_id):
        """Fetch trip details and the seat map for a ride (both requests are issued via ``pmap``).

        :param ride_id: mapping with ``ride_sid`` and ``ride_date`` keys
        :raises InvalidRide: when the trip does not consist of exactly one segment or has several seat maps
        """
        trip_id, trip_date = ride_id["ride_sid"], ride_id["ride_date"]
        trip, seats = pmap(
            lambda request: request(),
            [
                lambda: self.get(
                    "search/v2/trips/{}".format(trip_id),
                    params={"date": trip_date},
                    monitoring_path="search/v2/trips/ID",
                ),
                lambda: self.get(
                    "search/v2/trips/{}/seats".format(trip_id),
                    params={"date": trip_date},
                    monitoring_path="search/v2/trips/ID/seats",
                ),
            ],
        )
        segments, maps_seat = trip.get("segments", []), seats.get("maps_seat", [])
        if len(segments) != 1 or len(maps_seat) > 1:
            raise InvalidRide(context="got empty or multi-segment ride from partner by ride_id = {}".format(ride_id))
        return RideDetails.init({
            "segment": segments[0],
            "map_seat": maps_seat[0] if maps_seat else None,  # the seat map is optional
            "points": {point["id"]: point for point in trip["points"]},
        })

    def book(self, ride_id, passengers, pay_offline):
        """Lock seats and create a booking for the passengers.

        :param ride_id: mapping with ``ride_sid`` and ``ride_date`` keys
        :param passengers: non-empty sequence of passenger mappings; contacts are taken from the first one
        :param pay_offline: not used by this implementation
        :raises PartnerError: re-raised (after logging) when locking the seats fails
        :raises BookingError: when the partner returns a booking without tickets
        """
        trip_id, trip_date = ride_id["ride_sid"], ride_id["ride_date"]
        lock_seats_params = self._get_seats_data(passengers, trip_id, trip_date)
        try:
            lock_res = self.post('sale/v2/lockseats', json=lock_seats_params)
        except PartnerError:
            logger.exception('error lock seats for ride: %s', ride_id)
            raise

        contact = passengers[0]
        book_params = {
            'lockseats_id': lock_res['lockseats_id'],
            'mail': contact['email'],
            'nosendticket': True,
            'phone': contact['phone'],
            'passengers': [Book.init(passenger, passenger_number=i) for i, passenger in enumerate(passengers)],
        }
        raw_order = self.post('sale/v2/booking', json=book_params)
        tickets = raw_order.get('tickets')
        if not tickets:
            # Fail with a domain error instead of an IndexError/KeyError on the lookup below.
            raise BookingError(context='booking error for ride {}, no tickets in response'.format(ride_id))
        return Order.init(raw_order, order_id=tickets[0]['order_id'], status=common.STATUS_TO_ID['booked'])

    def confirm(self, order_id):
        """Buy out a previously booked order.

        :param order_id: mapping with an ``order_sid`` key
        :raises BookingError: when the partner does not approve the buyout or returns no tickets
        """
        params = {
            'order_id': order_id['order_sid'],
            'payment_method': 3   # NON_CASH, busfor recommendation
        }
        raw_order = self.put('sale/v2/booking/buyout', json=params)
        # A missing flag is treated the same way as an explicit refusal.
        if not raw_order.get('approval'):
            raise BookingError(context='booking error for order {}, approval is False'.format(order_id))
        tickets = raw_order.get('tickets')
        if not tickets:
            raise BookingError(context='booking error for order {}, no tickets in response'.format(order_id))
        return Order.init(
            raw_order,
            order_id=tickets[0]['order_id'],
            tickets_fwd={'__gen_url__': True},
        )

    def order(self, order_id):
        """Fetch the current state of an order.

        :param order_id: mapping with an ``order_sid`` key
        """
        raw_order = self.get('sale/v2/order', params={'orderId': order_id['order_sid']})
        # This endpoint returns tickets under ``newtickets``; expose them under the key ``Order`` is built from.
        tickets = raw_order['tickets'] = raw_order['newtickets']
        return Order.init(
            raw_order,
            order_id=tickets[0]['order_id'],
            tickets_fwd={
                '__gen_url__': True,
            }
        )

    def refund_info(self, ticket_id):
        """Return refund information for a single ticket of an order.

        :param ticket_id: mapping with ``order_sid`` and ``ticket_sid`` keys
        :raises InvalidTicket: when the order does not contain the ticket
        """
        order_sid, ticket_sid = ticket_id['order_sid'], ticket_id['ticket_sid']
        raw_order = self.get('sale/v2/order', params={'orderId': order_sid})
        for ticket in raw_order['newtickets']:
            if ticket['ticket_id'] == ticket_sid:
                return RefundInfo.init(ticket)
        raise InvalidTicket(context='refund_info: cannot find ticket_id {} in order {}'.format(ticket_sid, order_sid))

    def refund(self, ticket_id):
        """Refund a single ticket.

        :param ticket_id: mapping with ``order_sid`` and ``ticket_sid`` keys
        :raises RefundError: when the response does not contain exactly one position or the refund
            is not confirmed
        """
        order_sid, ticket_sid = ticket_id['order_sid'], ticket_id['ticket_sid']
        params = {
            "order_id": order_sid,
            "positions": [{"position": ticket_sid}]
        }
        response = self.put('sale/v2/order/return', json=params)
        positions = response.get('return_positions', [])
        if len(positions) != 1:
            raise RefundError(context='unexpected value in refund, ticket {}, order {}'.format(ticket_sid, order_sid))
        position = positions[0]
        if not position.get('confirmation', False):
            # ``failures`` may be absent; do not let a KeyError hide the actual refund error.
            raise RefundError(context='refund error, ticket {}, order {}, msg: {}'.format(
                ticket_sid, order_sid, position.get('failures')))
        return Refund.init(position)

    def ticket(self, ticket_id):
        raise NotImplementedError

    def ticket_blank(self, ticket_id):
        """Download the ticket archive from the partner and send its only file as an attachment.

        :param ticket_id: mapping with ``order_sid`` and ``ticket_sid`` keys
        :raises InvalidTicketBlank: when the archive does not contain exactly one file
        """
        order_sid, ticket_sid = ticket_id['order_sid'], ticket_id['ticket_sid']
        params = {
            'order_id': order_sid,
            'service_id': ticket_sid,
        }
        archive = self.get('sale/v2/order/ticket/stream', params=params, raw=True)
        with ZipFile(io.BytesIO(archive)) as zip_ticket:
            file_list = zip_ticket.namelist()
            if len(file_list) != 1:
                raise InvalidTicketBlank(context='unexpected file number in ticket archive, need 1, got {}'.format(
                    len(file_list)))
            data = zip_ticket.read(file_list[0])
        return send_file(
            io.BytesIO(data),
            mimetype='application/octet-stream',
            as_attachment=True,
            attachment_filename='ticket_{}.pdf'.format(ticket_sid)
        )

    def cancel(self, ticket_id):
        raise NotImplementedError

    def change_ride_endpoints(self, ride_id, pickup_sid, discharge_sid):
        raise NotImplementedError

    def _locations_endpoints(self):
        # type: () -> typing.Iterable[typing.Dict]
        """Yield endpoints built from every page of the locations dictionary."""
        for locations_page in self._fetch_pages("geo/v2/locations", defaults.LOCATIONS_LIMIT):
            dictionary = locations_page.get("dictionaries", {}).get(defaults.LANGUAGE, {})
            parents = {parent["id"]: parent for parent in locations_page["parent_locations"]}
            subtypes = {subtype["id"]: subtype for subtype in dictionary.get("location_subtypes", ())}
            for endpoint in Endpoint.init(
                locations_page["locations"],
                is_location=True,
                parents=parents,
                subtypes=subtypes,
            ):
                yield endpoint

    def _points_endpoints(self):
        # type: () -> typing.Iterable[typing.Dict]
        """Yield endpoints built from every page of the points dictionary."""
        for points_page in self._fetch_pages("geo/v2/points", defaults.POINTS_LIMIT):
            for endpoint in Endpoint.init(points_page["points"]):
                yield endpoint

    def _raw_segments(self):
        """Return ``(from, to)`` pairs of location-prefixed ids for every pair listed by the partner."""
        # A response without ``pairs`` yields no segments instead of failing on iteration over None.
        pairs = self.get("search/v2/pairs").get("pairs") or []
        prefix = Endpoint.LOCATION_PREFIX
        return [
            ('{}{}'.format(prefix, pair["location_id_from"]), '{}{}'.format(prefix, to))
            for pair in pairs
            for to in pair["location_id_to"]
        ]

    def _fetch_pages(self, path, page_limit, params=None):
        # type: (typing.Text, int, typing.Dict) -> typing.Iterable
        """Yield every page of a paginated resource.

        The first request returns the first page together with ``pages_info``; the remaining pages
        (numbers 2..page_count) are requested through a thread pool and yielded in completion order.
        Nothing is yielded when the first response is empty.
        """
        first_page = self.get(path, params=dict(params or {}, **{'limit': page_limit}))
        if not first_page:
            return
        pages_info = first_page["pages_info"]
        logger.debug("fetching %s pages %s", path, pages_info["page_count"])

        yield first_page

        def fetch_page(number):
            return self.get(
                "{}/page".format(path), params={"page_uuid": pages_info["page_uuid"], "number_page": number},
            )

        with ThreadPool(size=5) as pool:
            for page in pool.imap_unordered(fetch_page, range(2, pages_info["page_count"] + 1)):
                yield page

    def _compose_rides(self, pages, search_date_str):
        """Yield a ``Ride`` for every segment of every page, resolving references within that page."""
        for page in pages:
            # Index the page-level dictionaries by id once per page.
            vehicles = {vehicle['id']: vehicle for vehicle in page['vehicles']}
            points = {point['id']: point for point in page['points']}
            carriers = {carrier['id']: carrier for carrier in page['carriers']}
            for raw_ride in page['segments']:
                yield Ride.init(raw_ride, vehicles=vehicles, points=points, carriers=carriers,
                                search_date_str=search_date_str)

    def _get_seats_data(self, passengers, trip_id, trip_date):
        """Build the payload for the seat locking request.

        :param passengers: sequence of mappings with ``seatCode``, ``ticketTypeCode`` and ``birthDate`` keys
        :param trip_id: partner trip id, also used as the segment id
        :param trip_date: trip date as a ``%Y-%m-%d`` string
        """
        segment_seats = []
        passengers_info = []
        trip_dtm = datetime.strptime(trip_date, '%Y-%m-%d')
        for num, passenger in enumerate(passengers):
            seat_code = passenger['seatCode']
            # An empty seat code or '0' means no particular seat was chosen for this passenger.
            if seat_code and seat_code != '0':
                segment_seats.append({
                    'passenger_num': num,
                    'seat_id': seat_code,
                    'tariff_code': passenger['ticketTypeCode'],
                })
            # Drop tzinfo so the birth date can be compared with the naive trip date.
            birth_dtm = date_parse(passenger['birthDate']).replace(tzinfo=None)
            passengers_info.append({
                'passenger_age': relativedelta.relativedelta(trip_dtm, birth_dtm).years,
                'passenger_num': num,
            })
        return {
            'passengers': passengers_info,  # without passenger ages, child tariffs are replaced with the base tariff
            'passenger_count': len(passengers),
            'trip_id': trip_id,
            'trip_date': trip_date,
            'seats': [{
                "segment_id": trip_id,
                "segment_seats": segment_seats,
            }],
        }
