# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from datetime import datetime, timedelta
from itertools import chain, product

from yabus import common
from yabus.common.entities import RawSegments
from yabus.common.exceptions import OfflinePaymentDisabled, PointNotFound, PartnerError
from yabus.sks.baseclient import BaseClient
from yabus.sks.converter import point_converter
from yabus.sks.entities import Book, Endpoint, Order, RefundInfo, Ride, RideDetails, Refund
from yabus.sks.exceptions import parse_error
from yabus.util import decode_utf8
from yabus.util.parallelism import pmap, ThreadPool

logger = logging.getLogger(__name__)


class Client(BaseClient, common.Client):
    """SKS partner client: stations, ride search, booking, confirmation and refunds.

    The HTTP helpers used here (``get``, ``post``, ``session``) and the
    ``_stations`` attribute are not defined in this class; presumably they come
    from ``BaseClient`` -- verify there.
    """

    converter = point_converter

    def endpoints(self):
        """Return every partner station wrapped as an ``Endpoint``."""
        return [
            Endpoint.init(decode_utf8(x))
            for x in self.get('apigetstations')['stations']
        ]

    def segments(self):
        """Return all raw station pairs mapped through ``converter``."""
        return list(self.converter.gen_map_segments(self._raw_segments()))

    def raw_segments(self):
        """Return the unmapped station id pairs wrapped as ``RawSegments``."""
        return RawSegments.init({
            'segments': self._raw_segments(),
        })

    def psearch(self, from_sid, to_sid, date):
        """Search rides between two partner station ids on ``date``.

        Partner errors are logged and turned into an empty result, so one
        failing direction does not abort the multi-direction ``search``.

        :return: the partner's ``routes`` list, or ``[]``.
        """
        params = {
            'from-id': from_sid,
            'to-id': to_sid,
            'date': date.strftime('%Y-%m-%d'),
        }
        try:
            rides = self.post('apisearchroute', params=params)
            return rides['routes'] if rides else []
        except PartnerError:
            logger.exception('cannot get rides from_sid=%s to_sid=%s date=%s', from_sid, to_sid, date)
            return []

    def search(self, from_uid, to_uid, date, _=False):
        """Search rides between two points on ``date``.

        Each point uid is mapped to partner station ids via
        ``converter.deprecated_map``. Every (from, to) combination is then
        queried through ``pmap`` with ``psearch``.

        :return: ``Ride`` entities, or ``[]`` when a point cannot be mapped.
        """
        def fix_ride_time(ride):
            # The partner sends bare times; attach the search date. An arrival
            # that is not later than the departure is placed on the next day.
            time_fmt = '%H:%M:%S'
            departure = datetime.strptime(ride['departure'], time_fmt).time()
            ride['departure'] = datetime.combine(date, departure)
            arrival = datetime.strptime(ride['arrival'], time_fmt).time()
            ride['arrival'] = datetime.combine(date if arrival > departure else date + timedelta(days=1), arrival)
            return ride
        try:
            def func(route):
                return self.psearch(route[0], route[1], date)
            sids = map(self.converter.deprecated_map, [from_uid, to_uid])
            directions = product(*sids)
            rides = chain.from_iterable(pmap(func, directions))
            return Ride.init(map(fix_ride_time, rides), stations=self._stations)
        except PointNotFound:
            return []

    @parse_error
    def ride_details(self, ride_id):
        """Fetch prices and free seats for one ride and merge them into ``RideDetails``."""
        params = {
            'from-id': ride_id['from_sid'],
            'to-id': ride_id['to_sid'],
            'route-id': ride_id['ride_sid'],
            'date': ride_id['date'].strftime('%Y-%m-%d'),
        }

        def func(url):
            return self.post(url, params=params)
        ticket_types, free_seats = pmap(func, ['apigetprice', 'apifreeplacenumber'])
        return RideDetails.init(dict(ticket_types, **free_seats))

    def _get_neighbour_seats(self, free_seats, required_number):
        """Find a run of ``required_number`` consecutive seat numbers.

        :param free_seats: sorted list of integer seat numbers.
        :param required_number: length of the run to look for (>= 1).
        :return: index in ``free_seats`` where the run starts, or ``None``.
            The first run starting on an odd seat number is preferred. Failing
            that, the first run starting on an even seat number is returned.
        """
        first_even = None
        for start in range(len(free_seats) - required_number + 1):
            run = free_seats[start:start + required_number]
            if any(nxt != cur + 1 for cur, nxt in zip(run, run[1:])):
                continue
            if run[0] % 2 == 1:
                # The first odd-started run wins outright; no need to scan further.
                return start
            if run[0] % 2 == 0 and first_even is None:
                first_even = start
        return first_even

    def _get_free_seats(self, params, selected_seats, is_only_numbers):
        """Return the ride's free seats minus those already chosen by passengers.

        :return: ``(seats, is_only_numbers)``. When every remaining seat parses
            as an int, ``seats`` is a sorted list of ints and the incoming
            ``is_only_numbers`` is passed through. Otherwise ``seats`` holds the
            raw seat numbers and the flag is ``False``.
        """
        data = self.post('apifreeplacenumber', params=params)
        free_seats = [seat['number'] for seat in data['freeSeats']]

        for seat in selected_seats:
            if seat in free_seats:
                free_seats.remove(seat)

        try:
            free_seats_as_numbers = sorted(int(seat) for seat in free_seats)
        except ValueError:
            return free_seats, False

        return free_seats_as_numbers, is_only_numbers

    def _find_additional_seats(self, free_seats, selected_seats, required_number, is_only_numbers):
        """Pick seats for the passengers that did not choose one.

        If seats are not numeric, the first free seats are taken. Otherwise:

        * with no preselected seats, the group is placed on the longest runs
          of neighbouring seats available;
        * with preselected seats, each extra seat is the free seat nearest to
          any preselected one.

        :param free_seats: output of ``_get_free_seats``; chosen seats are
            removed from this list in place.
        :param selected_seats: seat codes already chosen by passengers.
        :param required_number: total number of passengers needing a seat.
        :param is_only_numbers: whether all seats can be treated as ints.
        :return: list of extra seats. It is shorter than
            ``required_number - len(selected_seats)`` when there are not
            enough free seats; previously that case looped forever or raised
            ``IndexError``.
        """
        missing = required_number - len(selected_seats)
        if not is_only_numbers:
            return free_seats[:missing]

        selected_free_seats = []
        if not selected_seats:
            # Stop when the bus runs out of seats: with an empty list no run of
            # any length exists and the loop below could never make progress.
            while len(selected_free_seats) < missing and free_seats:
                for number in range(missing - len(selected_free_seats), 0, -1):
                    first_position = self._get_neighbour_seats(free_seats, number)
                    if first_position is None:
                        continue
                    run_end = first_position + number
                    selected_free_seats.extend(free_seats[first_position:run_end])
                    del free_seats[first_position:run_end]
                    break
        else:
            # is_only_numbers guarantees every preselected seat parses as an int.
            selected_numbers = [int(seat) for seat in selected_seats]

            def distance(seat):
                return min(abs(seat - selected) for selected in selected_numbers)

            while len(selected_free_seats) < missing and free_seats:
                # min() returns the first of equally distant seats, as before.
                nearest = min(free_seats, key=distance)
                selected_free_seats.append(nearest)
                free_seats.remove(nearest)

        return selected_free_seats

    def _get_selected_seats(self, passengers):
        """Collect the non-empty ``seatCode`` values of ``passengers``.

        :return: ``(selected_seats, is_only_numbers)``. The flag is ``False``
            if any chosen seat code does not parse as an int.
        """
        selected_seats = [passenger['seatCode'] for passenger in passengers if passenger['seatCode']]
        is_only_numbers = True
        for seat in selected_seats:
            try:
                int(seat)
            except ValueError:
                is_only_numbers = False

        return selected_seats, is_only_numbers

    def _order_from_response(self, data):
        """Build an ``Order`` from an ``apibook``/``apiconfirm`` response."""
        order = data['order']
        tickets = data['tickets']
        return Order.init(dict(tickets=tickets, **order), tickets_fwd={'status': order['status']})

    @parse_error
    def book(self, ride_id, passengers, pay_offline):
        """Book a ride for ``passengers``.

        Passengers without a ``seatCode`` get one assigned in place from the
        currently free seats. If there are not enough free seats, the surplus
        passengers are sent with their ``seatCode`` left unchanged.

        :raises OfflinePaymentDisabled: when ``pay_offline`` is requested.
        """
        if pay_offline:
            raise OfflinePaymentDisabled
        params = {
            'from-id': ride_id['from_sid'],
            'to-id': ride_id['to_sid'],
            'route-id': ride_id['ride_sid'],
            'date': ride_id['date'].strftime('%Y-%m-%d'),
        }

        selected_seats, is_only_numbers = self._get_selected_seats(passengers)

        # If SKS ever adds luggage, do not assign a separate seat to it.
        if len(selected_seats) != len(passengers):
            free_seats, is_only_numbers = self._get_free_seats(params, selected_seats, is_only_numbers)
            additional_seats = self._find_additional_seats(free_seats, selected_seats, len(passengers), is_only_numbers)

            seatless = [passenger for passenger in passengers if not passenger['seatCode']]
            # zip stops at the shorter list, so a seat shortage no longer
            # raises IndexError here.
            for passenger, seat in zip(seatless, additional_seats):
                passenger['seatCode'] = str(seat)

        payload = {
            'Passengers': [Book.init(x) for x in passengers]
        }
        data = self.post('apibook', json=payload, params=params)
        return self._order_from_response(data)

    @parse_error
    def confirm(self, order_id):
        """Confirm a previously booked order and return it as an ``Order``."""
        data = self.post('apiconfirm', params={'order-id': order_id['order_sid']})
        return self._order_from_response(data)

    @parse_error
    def refund_info(self, ticket_id):
        """Return empty refund info; no partner request is made."""
        return RefundInfo.init({})

    @parse_error
    def refund(self, ticket_id):
        """Cancel one ticket of an order and return the first ticket of the reply."""
        params = {'order-id': ticket_id['order_sid'], 'tickets[{}]'.format(ticket_id['idx']): ticket_id['ticket_sid']}
        data = self.post('apicancelorder', params=params)
        return Refund.init(dict(data['order']['tickets'][0]))

    def _raw_segments(self):
        """Return ``(station_id, destination_id)`` pairs for every station.

        Stations are queried on a pool of 10 threads sharing one session.
        The result is materialized before the session context is left.
        """
        stations = (x['id'] for x in self.get('apigetstations')['stations'])
        with ThreadPool(size=10) as pool:
            with self.session() as session:
                return list(chain.from_iterable(pool.map(
                    lambda x: ((x, y['id']) for y in self._routes_from(session, x)),
                    stations,
                )))

    def _routes_from(self, session, station_id):
        """Return destinations reachable from ``station_id``, or ``[]`` on partner error."""
        try:
            return session.get('apigetroutesfrom', params={
                'stationid': station_id,
            })['stations']
        except PartnerError:
            # Use the module logger like psearch does, not the root logger.
            logger.exception('cannot get segments from station_id=%s', station_id)
            return []
