# -*- coding: utf-8 -*-
from datetime import datetime
from logging import getLogger
from typing import Iterator, List
from urlparse import urljoin

import requests

from travel.avia.ticket_daemon.ticket_daemon.api.flights import FlightFabric, Segment, Variant
from travel.avia.ticket_daemon.ticket_daemon.api.query import QueryIsNotValid
from travel.avia.ticket_daemon.ticket_daemon.daemon.utils import sleep_every, BadPartnerResponse
from travel.avia.ticket_daemon.ticket_daemon.lib.currency import Price
from travel.avia.ticket_daemon.ticket_daemon.lib.tracker import QueryTracker
from travel.avia.ticket_daemon.ticket_daemon.lib.partner_secret_storage import partner_secret_storage
from travel.avia.ticket_daemon.ticket_daemon.lib.baggage import Baggage


log = getLogger(__name__)

# Partner API endpoints (all relative to the same host).
API_URL = 'https://ya1-api.flysmartavia.com'
SEARCH_URL = urljoin(API_URL, '/api/avia/search')
RESULTS_URL = urljoin(API_URL, '/api/avia/results')
BOOK_REDIRECT_URL = urljoin(API_URL, '/booking-meta-redirect')
SMARTAVIA_PARTNER_CODE = 'smartavia'
# Partner identifiers read from the partner secret storage once, at import time.
SMARTAVIA_SALEPOINT = partner_secret_storage.get(importer_name=SMARTAVIA_PARTNER_CODE, namespace='SALEPOINT')
SMARTAVIA_TRAFFICSOURCE = partner_secret_storage.get(importer_name=SMARTAVIA_PARTNER_CODE, namespace='TRAFFICSOURCE')
SMARTAVIA_SEARCH_ENGINE = partner_secret_storage.get(importer_name=SMARTAVIA_PARTNER_CODE, namespace='SEARCH_ENGINE')
# strptime/strftime formats:
#   DATETIME_FORMAT      - segment departure/arrival datetimes in partner results
#   DATE_SEARCH_FORMAT   - dates in the search request and in segment_params['date']
#   DATE_REDIRECT_FORMAT - dates sent in the booking redirect form
DATETIME_FORMAT = '%Y-%m-%d %H:%M'
DATE_SEARCH_FORMAT = '%Y-%m-%d'
DATE_REDIRECT_FORMAT = '%d.%m.%Y'
# Query language -> partner 'lang' parameter; unmapped languages fall back to 'ru'
# where the map is used (build_order_data).
LANG_MAP = {'ru': 'ru', 'en': 'en'}
# UTM tracking parameters added to the booking redirect POST data.
UTM_CAMPAIGN = '478-h3-21'
UTM_MEDIUM = 'cpa'
UTM_SOURCE = 'travel.yandex.ru'


class DictList(object):
    """Collect values keyed by (joint_index, segment_index) and emit them in key order.

    Storing a value under an already-used key replaces the earlier value.
    """

    def __init__(self):
        # {(joint_index, segment_index): value}
        self.data = {}

    def append(self, joint_index, segment_index, value):
        """Store ``value`` at position (joint_index, segment_index)."""
        position = (joint_index, segment_index)
        self.data[position] = value

    def flatten(self):
        """Return all stored values ordered by (joint_index, segment_index)."""
        ordered_positions = sorted(self.data)
        return [self.data[position] for position in ordered_positions]


def validate_query(q):
    """Accept only economy-class queries.

    Raises:
        QueryIsNotValid: when ``q.klass`` is anything other than 'economy'.
    """
    is_economy = q.klass == 'economy'
    if is_economy:
        return
    raise QueryIsNotValid('Only economy class requests are allowed')


@QueryTracker.init_query
def query(tracker, q):
    """Fetch partner results for ``q`` and return the parsed variants as a list.

    ``tracker`` is passed through to get_data, which issues the HTTP requests
    via ``tracker.wrap_request``.
    """
    return list(generate_variants(get_data(tracker, q), q))


def build_search_params(q):
    """Build the JSON body for the partner search request.

    A single forward segment is always sent; when ``q.date_backward`` is set,
    a return segment with swapped airports is appended. Passenger counts that
    are missing from ``q.passengers`` default to 0.
    """
    origin = q.iata_from.encode('utf-8')
    destination = q.iata_to.encode('utf-8')

    legs = [(origin, destination, q.date_forward)]
    if q.date_backward:
        legs.append((destination, origin, q.date_backward))

    passengers = q.passengers
    return {
        'salepoint': SMARTAVIA_SALEPOINT,
        'trafficsource': SMARTAVIA_TRAFFICSOURCE,
        'adt': passengers.get('adults', 0),
        'chd': passengers.get('children', 0),
        'inf': passengers.get('infants', 0),
        'segment': [
            {
                'departure': leg_from,
                'arrival': leg_to,
                'date': leg_date.strftime(DATE_SEARCH_FORMAT),
            }
            for leg_from, leg_to, leg_date in legs
        ],
    }


def get_data(tracker, q):
    """Run the two-step partner search and return the decoded results payload.

    The search request returns a ``result_id``; a second request to the
    results endpoint fetches the actual offers for that id.

    Raises:
        requests.HTTPError: if either request returns an HTTP error status.
        BadPartnerResponse: if the search response yields no ``result_id``.
    """
    search_response = tracker.wrap_request(
        requests.post,
        SEARCH_URL,
        json=build_search_params(q),
    )
    search_response.raise_for_status()

    result_id = extract_result_id(search_response.json())
    if not result_id:
        raise BadPartnerResponse(SMARTAVIA_PARTNER_CODE, search_response)

    results_response = tracker.wrap_request(
        requests.get,
        RESULTS_URL,
        params={
            'result_id': result_id,
        },
    )
    # Same check as for the search step: an HTTP error body must not be
    # parsed and passed on as if it were a results payload.
    results_response.raise_for_status()
    return results_response.json()


def extract_result_id(result):
    """Return the ``result_id`` from a decoded search response, or None.

    None is returned, after logging, when the payload is not a dict, its
    ``status`` is not 'success', or it has no ``data`` section. If ``data`` is
    present but has no ``result_id``, None is returned without logging.
    """
    # The payload comes straight from ``response.json()`` and may not be a
    # dict at all; treat that like any other non-success response.
    if not isinstance(result, dict) or result.get('status') != 'success':
        log.error('Non-success status from %s: %r', SMARTAVIA_PARTNER_CODE, result)
        return None

    data = result.get('data')
    if not data:
        # Logger.warn is a deprecated alias of Logger.warning.
        log.warning('No data in result from %s', SMARTAVIA_PARTNER_CODE)
        return None

    return data.get('result_id')


def generate_variants(result, q):
    # type: (dict, Query) -> Iterator[Variant]
    """Yield a Variant for every usable offer in the decoded results payload.

    Nothing is yielded (and the reason is logged) when the payload is not a
    successful dict response or has no ``data`` section. Offers without
    flights, or whose forward/backward legs parse to no segments, are skipped.

    Only the first two entries of an offer's ``flights`` are parsed as legs:
    index 0 is the forward leg and index 1, if present, the backward leg. The
    whole ``flights`` list is passed to build_order_data.
    """
    # The payload comes straight from ``response.json()``; guard against a
    # non-dict body instead of failing with AttributeError.
    if not isinstance(result, dict) or result.get('status') != 'success':
        log.error('Non-success status from %s in query: %r', SMARTAVIA_PARTNER_CODE, q.id)
        return

    variants_data = result.get('data')
    if not variants_data:
        log.info('No variants data from %s in query: %r', SMARTAVIA_PARTNER_CODE, q.id)
        return

    # A ``data`` section without ``variants`` means "no offers", not a KeyError.
    data = variants_data.get('variants')
    if not data:
        return

    for variant_data in sleep_every(data):
        v = Variant()
        v.klass = q.klass
        price_data = variant_data['total_price']
        currency = price_data.get('currency')
        v.tariff = Price(float(price_data['value']), currency)

        flight_data = variant_data.get('flights')
        if not flight_data:
            continue
        v.forward.segments = parse_segments(flight_data[0], q.importer.flight_fabric, v.klass)
        if not v.forward.segments:
            continue
        if len(flight_data) > 1:
            v.backward.segments = parse_segments(flight_data[1], q.importer.flight_fabric, v.klass)
            if not v.backward.segments:
                continue

        v.order_data = build_order_data(flight_data, currency, q)

        yield v


def parse_segments(segments_data, flight_fabric, klass):
    # type: (dict, FlightFabric, str) -> List[Segment]
    """Create one flight segment per entry of ``segments_data['segments']``.

    Returns an empty list (consistent with the declared ``List[Segment]``
    return type) when the ``segments`` key is missing or empty. The fare code
    is taken from the segment's first ``prices`` entry.
    """
    segments_elems = segments_data.get('segments')
    if not segments_elems:
        return []

    return [
        flight_fabric.create(
            station_from_iata=segment['departure']['airport']['code'],
            station_to_iata=segment['arrival']['airport']['code'],
            local_departure=datetime.strptime(segment['departure']['datetime'], DATETIME_FORMAT),
            local_arrival=datetime.strptime(segment['arrival']['datetime'], DATETIME_FORMAT),
            company_iata=segment['airline']['code'],
            pure_number=segment['flight_number'],
            klass=klass,
            fare_code=segment['prices'][0]['fare']['base_code'],
            fare_family=segment['segment_params']['brand'],
            baggage=parse_baggage(segment['brand_options']),
        )
        for segment in segments_elems
    ]


def parse_baggage(brand_options):
    """Build a Baggage from a segment's ``brand_options``.

    Missing or falsy ``luggage_count`` / ``luggage_weight`` values become 0.
    """
    pieces = brand_options.get('luggage_count') or 0
    weight = brand_options.get('luggage_weight') or 0
    return Baggage.from_partner(pieces=pieces, weight=weight)


def build_order_data(flight_data, currency, q):
    # type: (list, str, Query) -> dict
    """Build the order payload used to redirect the user to partner booking.

    Args:
        flight_data: the offer's ``flights`` list. Each entry ("joint") holds
            a ``segments`` list as returned by the partner.
        currency: currency of the offer's total price.
        q: the search query; passengers, language and id are read from it.

    Returns:
        dict with ``params`` (booking form fields; list values hold one item
        per joint or per segment) and ``qid`` (the query id).
    """
    # One entry per joint, in flight_data order.
    carriers = []
    classes = []
    # Per-segment fields are keyed by (flight_index, segment position), so the
    # flattened lists are ordered by flight_index, then by position within it.
    origin_city_codes = DictList()
    destination_city_codes = DictList()
    flight_numbers = DictList()
    subclasses = DictList()
    dates = DictList()
    joint_ids = DictList()
    brands = DictList()
    segments_count = 0
    direct_only = True
    for joint in flight_data:
        joint_carrier = None
        joint_cabin = None

        segments = joint['segments']
        if len(segments) > 1:
            # A connection on any joint clears the direct-only flag.
            direct_only = False

        for segment_index, segment in enumerate(segments):
            segments_count += 1
            segment_params = segment['segment_params']
            # The joint's carrier and cabin come from its first segment that
            # provides a truthy value.
            if not joint_carrier:
                joint_carrier = segment_params['airline']
            if not joint_cabin:
                joint_cabin = segment['booking_subclass']['cabin']
            # NOTE(review): per-segment lists are ordered by the partner's
            # flight_index, while carriers/classes follow flight_data order;
            # presumably the two agree — confirm with the partner spec.
            joint_index = segment_params['flight_index']
            origin_city_codes.append(joint_index, segment_index, segment_params['departure'])
            destination_city_codes.append(joint_index, segment_index, segment_params['arrival'])
            flight_numbers.append(joint_index, segment_index, segment_params['flight_number'])
            subclasses.append(joint_index, segment_index, segment_params['subclass'])
            # Re-format the date from the search format (%Y-%m-%d) into the
            # redirect form format (%d.%m.%Y).
            departure_date = datetime.strptime(segment_params['date'], DATE_SEARCH_FORMAT)
            dates.append(joint_index, segment_index, departure_date.strftime(DATE_REDIRECT_FORMAT))
            brands.append(joint_index, segment_index, segment_params['brand'])
            # Joint id sent to the partner is flight_index + 1.
            joint_ids.append(joint_index, segment_index, str(joint_index + 1))

        carriers.append(joint_carrier)
        classes.append(joint_cabin)

    params = {
        'actualCurrency': currency,
        'carrier': carriers,
        'flight': flight_numbers.flatten(),
        'origin-city-code': origin_city_codes.flatten(),
        'destination-city-code': destination_city_codes.flatten(),
        'class': classes,
        'subclass': subclasses.flatten(),
        'count-aaa': str(q.passengers.get('adults', 0)),
        'count-rbg': str(q.passengers.get('children', 0)),
        'count-rmg': str(q.passengers.get('infants', 0)),
        'date': dates.flatten(),
        'direct-only': '1' if direct_only else '0',
        'joint-id': joint_ids.flatten(),
        'lang': LANG_MAP.get(q.lang, 'ru'),
        'requestedBrands': brands.flatten(),
        'searchPeriod': '0',
        'segmentsCount': str(segments_count),
    }
    return {
        'params': params,
        'qid': q.id,
    }


def book(order_data):
    """Build the POST redirect to the partner booking page.

    Each list-valued entry of ``order_data['params']`` is expanded into
    indexed form fields (``key[0]``, ``key[1]``, ...); scalar entries are
    copied unchanged. Search-engine and UTM tracking fields are then added.

    Returns:
        dict with the redirect ``url`` and the ``post_data`` form fields.
    """
    post_data = {}
    # items() works on both Python 2 and 3, unlike iteritems().
    for key, value in order_data['params'].items():
        if isinstance(value, list):
            for index, item in enumerate(value):
                post_data['%s[%d]' % (key, index)] = item
        else:
            post_data[key] = value

    post_data['search-engine'] = SMARTAVIA_SEARCH_ENGINE
    post_data['utm_campaign'] = UTM_CAMPAIGN
    post_data['utm_medium'] = UTM_MEDIUM
    post_data['utm_source'] = UTM_SOURCE
    return {'url': BOOK_REDIRECT_URL, 'post_data': post_data}
