# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import io
from flask import send_file
from itertools import chain, product
from logging import getLogger

from yabus import common
from yabus.common.entities import RawSegments
from yabus.common.exceptions import OfflinePaymentDisabled, PointNotFound, InvalidIdentifier
from yabus.ok.baseclient import BaseClient
from yabus.ok.converter import point_converter
from yabus.ok.entities import Book, Endpoint, Order, Refund, RefundInfo, Ride, RideDetails, Ticket
from yabus.ok.exceptions import parse_error
from yabus.util import decode_utf8, deduplicate
from yabus.util.bottleneck import bottleneck
from yabus.util.parallelism import ThreadPool, pmap

logger = getLogger(__name__)


class Client(common.Client, BaseClient):
    converter = point_converter

    def endpoints(self):
        """Return every known city and station as ``Endpoint`` objects.

        The raw endpoint dicts come from ``self._endpoints``, which returns a
        lazy iterator.  It is materialized while the thread pool and the
        session are still open, so no deferred work can run against an
        already released pool or session.

        :return: list of ``Endpoint.init`` results, one per raw endpoint.
        """
        with ThreadPool(size=30) as pool, self.session() as session:
            # Force evaluation inside the ``with`` block on purpose.
            raw_endpoints = list(self._endpoints(pool, session))
        return [
            Endpoint.init(decode_utf8(raw_endpoint))
            for raw_endpoint in raw_endpoints
        ]

    def raw_segments(self):
        """Return the unconverted segment pairs wrapped by ``RawSegments.init``."""
        wrapped = {'segments': self._raw_segments()}
        return RawSegments.init(wrapped)

    def segments(self):
        """Return the raw segment pairs passed through the converter, as a list."""
        map_segments = self.converter.gen_map_segments
        return list(map_segments(self._raw_segments()))

    def psearch(self, from_sid, to_sid, date, session):
        """Fetch all rides between two supplier points for one date.

        A supplier id starting with ``'s'`` is sent as a station id, anything
        else as a city id; the first character is stripped before sending.
        When the response reports more than one page, the remaining pages are
        requested through ``pmap`` and appended to the first page.

        :param from_sid: prefixed supplier id of the departure point.
        :param to_sid: prefixed supplier id of the arrival point.
        :param date: object with ``strftime``; sent as ``YYYY-MM-DD``.
        :param session: session object used for the ``ride/list`` requests.
        :return: ``Ride.init`` applied to the collected rides, or ``[]`` when
            anything fails (the error is logged, not raised).
        """
        start_key = 'station_id_start' if from_sid.startswith('s') else 'city_id_start'
        end_key = 'station_id_end' if to_sid.startswith('s') else 'city_id_end'
        payload = {
            start_key: from_sid[1:],
            end_key: to_sid[1:],
            'date': date.strftime('%Y-%m-%d'),
        }
        try:
            response = session.get('ride/list', payload)
            total_pages = response.get('_ride_list_meta', {}).get('total_pages')
            rides = response.get('ride_list', [])
            # Convert once: the value may arrive as a string, and the old code
            # only converted it for the comparison, not for the range below.
            page_count = int(total_pages) if total_pages is not None else 1
            if page_count > 1:
                def fetch_page(page):
                    return session.get('ride/list', dict(payload, page=page)).get('ride_list', [])

                rides.extend(chain.from_iterable(pmap(fetch_page, range(2, page_count + 1))))
            return Ride.init(rides)
        except Exception:
            # Deliberate best effort: one failing direction must not break the search.
            logger.exception('error getting rides')
            return []

    def search(self, from_uid, to_uid, date, _=False):
        """Search rides between two points on the given date.

        Both identifiers are translated by ``self.converter.map``; every
        combination of the resulting supplier ids whose first characters match
        is searched via ``psearch`` and the results are concatenated.

        :return: list of rides; ``[]`` when ``PointNotFound`` is raised.
        """
        try:
            with self.session() as session:
                def search_pair(pair):
                    origin, destination = pair
                    return self.psearch(origin, destination, date, session)

                origins = self.converter.map(from_uid)
                destinations = self.converter.map(to_uid)
                #  only station-station or city-city searches allowed
                same_kind_pairs = (
                    (origin, destination)
                    for origin, destination in product(origins, destinations)
                    if origin[0] == destination[0]
                )
                found = pmap(search_pair, same_kind_pairs)
                return list(chain.from_iterable(found))
        except PointNotFound:
            return []

    @parse_error('ride_details')
    def ride_details(self, ride_id):
        """Return ``RideDetails`` built from three API calls merged into one dict.

        The same payload is sent to each method through ``pmap``; on key
        collisions a later response overrides an earlier one.

        :param ride_id: mapping that provides ``'ride_sid'``.
        """
        payload = {'ride_segment_id': ride_id['ride_sid']}
        methods = ['ride/position/free', 'card_identity/list', 'ride']

        def fetch(method):
            return self.get(method, payload)

        merged = {}
        for part in pmap(fetch, methods):
            merged.update(part.items())
        return RideDetails.init(merged, citizenships=self.countries)

    @parse_error('book')
    def book(self, ride_id, passengers, pay_offline):
        """Create a temporary booking for the given passengers.

        :param ride_id: mapping that provides ``'ride_sid'``.
        :param passengers: iterable of passenger records, each passed to ``Book.init``.
        :param pay_offline: truthy values are rejected.
        :raises OfflinePaymentDisabled: when ``pay_offline`` is truthy.
        :return: ``Order.init`` result for the ``'operation'`` part of the response.
        """
        if pay_offline:
            raise OfflinePaymentDisabled
        ride_sid = ride_id['ride_sid']
        ticket_data = json.dumps([Book.init(passenger) for passenger in passengers])
        response = self.post('operation/booking/tmp', {
            'ride_segment_id': ride_sid,
            'ticket_data': ticket_data,
        })
        operation = response['operation']
        return Order.init(operation, tickets_fwd={'__countries__': self.countries})

    @parse_error('confirm')
    def confirm(self, order_id):
        """Buy a previously booked operation and return the resulting ``Order``.

        :param order_id: mapping that provides ``'order_sid'``.
        """
        payload = {'operation_id': order_id['order_sid'], 'operation_renew': 0}
        with self.session() as session:
            operation = session.post('operation/buy', payload)['operation']
            forwarded = {
                '__countries__': self.countries,
                '__gen_url__': True,
            }
            return Order.init(operation, tickets_fwd=forwarded)

    @parse_error('order')
    def order(self, order_id):
        """Fetch an existing operation and return it as an ``Order``.

        :param order_id: mapping that provides ``'order_sid'``.
        """
        payload = {'operation_id': order_id['order_sid']}
        with self.session() as session:
            operation = session.get('operation/', payload)['operation']
            forwarded = {
                '__countries__': self.countries,
                '__gen_url__': True,
            }
            return Order.init(operation, tickets_fwd=forwarded)

    @parse_error('refund')
    def refund_info(self, ticket_id):
        """Return ``RefundInfo`` from the refund calculation for one ticket.

        A response without ``'ticket_refund'`` is treated as an empty dict.

        :param ticket_id: mapping that provides ``'order_sid'`` and ``'ticket_sid'``.
        """
        order_sid = ticket_id['order_sid']
        ticket_sid = ticket_id['ticket_sid']
        response = self.get('ticket/refund/calc', {
            'operation_id': order_sid,
            'ticket_id': ticket_sid,
        })
        calculation = response.get('ticket_refund', {})
        return RefundInfo.init(calculation)

    @parse_error('refund')
    def refund(self, ticket_id):
        """Refund one ticket and return the ``Refund`` built from the response.

        :param ticket_id: mapping that provides ``'order_sid'`` and ``'ticket_sid'``.
        """
        order_sid = ticket_id['order_sid']
        ticket_sid = ticket_id['ticket_sid']
        response = self.post('ticket/refund', {
            'operation_id': order_sid,
            'ticket_id': ticket_sid,
        })
        return Refund.init(response['ticket_refund'])

    @parse_error('ticket')
    def ticket(self, ticket_id):
        """Fetch one ticket and return it as a ``Ticket``.

        :param ticket_id: mapping that provides ``'ticket_sid'``.
        """
        response = self.get('ticket', {'ticket_id': ticket_id['ticket_sid']})
        raw_ticket = response['ticket']
        return Ticket.init(raw_ticket, __countries__=self.countries, __gen_url__=True)

    @parse_error('ticket_blank')
    def ticket_blank(self, ticket_id):
        """Download the ticket PDF and return it as a file attachment response.

        The operation is fetched first because the PDF request needs its hash.

        :param ticket_id: mapping that provides ``'ticket_sid'`` and ``'order_sid'``.
        :raises InvalidIdentifier: when either key is missing from ``ticket_id``.
        :return: result of ``flask.send_file`` with the PDF bytes.
        """
        try:
            ticket_sid, order_sid = ticket_id['ticket_sid'], ticket_id['order_sid']
        except KeyError as err:
            raise InvalidIdentifier(err)

        operation = self.get('operation', {'operation_id': order_sid})
        pdf_params = {
            'ticket_id': ticket_sid,
            'operation_hash': operation['operation']['hash'],
        }
        pdf_response = self.get('ticket/pdf', params=pdf_params, raw=True)
        pdf_response.raise_for_status()

        pdf_stream = io.BytesIO(pdf_response.content)
        filename = 'ticket_{}.pdf'.format(ticket_sid)
        return send_file(
            pdf_stream,
            mimetype='application/octet-stream',
            as_attachment=True,
            attachment_filename=filename,
        )

    def _raw_segments(self):
        """Collect (departure, arrival) identifier pairs for stations and cities.

        For each point type ('station', 'city') the ``<type>/list/from`` method
        is queried for departure points, then ``<type>/list/to`` is queried
        once per departure point.  Each identifier is prefixed with the first
        letter of its point type.

        :return: list of ``(from_id, to_id)`` tuples.
        """
        with ThreadPool(size=10) as pool:
            with self.session() as session:
                def func(type_):
                    # Points of this type that the API lists as departures.
                    dispatches = session.get('{}/list/from'.format(type_), params={})['{}_list'.format(type_)]
                    prefix = type_[0]
                    id_key = '{}_id'.format(type_)
                    # NOTE(review): this pool.map is issued from a function that is
                    # itself mapped over the same pool — presumably safe with only
                    # two outer tasks; confirm against ThreadPool's implementation.
                    return chain.from_iterable(pool.map(
                        # The generator's source iterable (the session.get call) is
                        # evaluated when the generator expression is created, i.e.
                        # during the mapped call; only the tuple building is deferred.
                        lambda x: (
                            # String concatenation: assumes ids arrive as strings — TODO confirm.
                            (prefix + x[id_key], prefix + y[id_key])
                            for y in session.get('{}/list/to'.format(type_), params={
                            '{}_id_start'.format(type_): x[id_key]
                        })['{}_list'.format(type_)]
                        ),
                        dispatches,
                    ))
                # Materialize before the session and the pool are released.
                return list(chain.from_iterable(pool.map(func, ['station', 'city'])))

    def _endpoints(self, pool, session):
        """Build a lazy iterator over raw city and station endpoint dicts.

        Cities come from ``city/list/from`` and ``city/list/to``, de-duplicated
        by ``'city_id'``.  Stations are looked up per city title and direction
        through ``pool.map`` and de-duplicated by ``'station_id'``.  Every
        yielded dict is a copy of the source record extended with ``'id'``
        (type letter + source id), ``'title'``, ``'country_code'`` and ``'type'``.

        :param pool: pool object whose ``map`` runs the station lookups.
        :param session: session object used for all requests.
        :return: iterator yielding cities first, then stations.
        """
        directions = ['from', 'to']

        city_lists = (
            session.get('city/list/{}'.format(direction), {})['city_list']
            for direction in directions
        )
        city_endpoints = list(deduplicate(
            chain.from_iterable(city_lists),
            key=lambda city: city['city_id'],
        ))

        city_titles = [city['city_title'] for city in city_endpoints]

        station_queries = (
            (title, direction, session)
            for title, direction in product(city_titles, directions)
        )
        station_lists = pool.map(self._get_stations, station_queries)
        station_endpoints = deduplicate(
            chain.from_iterable(station_lists),
            key=lambda station: station['station_id'],
        )

        def tagged(point_type, points):
            # Lazily extend each record with the normalized endpoint fields.
            for point in points:
                yield dict(
                    point,
                    **{
                        'id': point_type[0] + point['{}_id'.format(point_type)],
                        'title': point['{}_title'.format(point_type)],
                        'country_code': point.get('country_iso3166') or point.get('country_iso'),
                        'type': point_type,
                    }
                )

        return chain(
            tagged('city', city_endpoints),
            tagged('station', station_endpoints),
        )

    @bottleneck(calls=100)  # rps
    def _get_stations(self, args):
        """Fetch the station list matching one city title for one direction.

        The arguments arrive as a single tuple so the method can be used
        directly as a ``pool.map`` callback.  The tuple is unpacked in the body
        because tuple parameters in a signature are a syntax error on Python 3.

        :param args: ``(city, direction, session)`` tuple; ``city`` is sent as
            the ``'query'`` parameter, ``direction`` selects the
            ``station/list/<direction>`` method.
        :return: the ``'station_list'`` value of the response.
        """
        city, direction, session = args
        return session.get('station/list/{}'.format(direction), {'query': city})['station_list']
