# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from datetime import date, datetime
from os import environ

from dateutil.relativedelta import relativedelta
from dateutil.tz import UTC
from psycopg2.extras import RealDictCursor
from six import text_type

from travel.cpa.collectors.buses.rates import BusesRates
from travel.cpa.collectors.lib.collector import Collector
from travel.cpa.lib.common import with_retries
from travel.cpa.lib.lib_datetime import parse_datetime_iso, timestamp
from travel.cpa.lib.lib_logging import get_logger
from travel.cpa.lib.order_snapshot import OrderCurrencyCode, OrderStatus

LOG = get_logger(__name__)


class BusesCollector(Collector):
    """Collect bus order snapshots from the buses service database.

    Orders that have an ``orders_log`` entry within ``[updated_from, updated_to]``
    and whose ``booking->>'partner'`` matches ``PARTNER_NAME`` (SQL ``LIKE``)
    are loaded and converted into ``SNAPSHOT_CLS`` instances.

    Concrete subclasses are expected to provide ``SNAPSHOT_CLS`` (``None`` here)
    and ``PARTNER_NAME`` (not defined in this class body).
    """

    SNAPSHOT_CLS = None

    # Bus service order status -> snapshot OrderStatus.
    # A status missing from this mapping makes _get_order_from_raw raise KeyError.
    STATUS_MAPPING = {
        'booked': OrderStatus.UNPAID,
        'ordered': OrderStatus.UNPAID,
        'purchase_error': OrderStatus.UNPAID,
        'purchased': OrderStatus.PAID,
        'confirmed': OrderStatus.CONFIRMED,
        'canceled': OrderStatus.CANCELLED,
    }

    # Passenger ticketType id -> per-order counter field incremented for it.
    TICKET_TYPE_MAPPER = {
        1: 'adult_passengers_count',
        2: 'children_with_seats_count',
        3: 'baggage_tickets_count',
    }

    # ``tracking`` key -> label field. A nested dict maps sub-keys of that
    # tracking entry to label fields; a plain string copies the value as-is.
    LABELS_MAPPER = {
        'query': {
            'req_id': 'label_req_id',
            'utm_source': 'label_source',
            'utm_medium': 'label_medium',
            'utm_campaign': 'label_campaign',
            'utm_term': 'label_term',
            'utm_content': 'label_content',
            'gclid': 'label_gclid',
            'partner': 'label_partner',
            'subpartner': 'label_subpartner',
            'partner_uid': 'label_partner_uid',
            'ytp_referer': 'label_ytp_referer',
        },
        'platform': 'label_user_device',
        'uid': 'label_passport_uid',
        'cookies': {
            'i': 'label_icookie',
            'yandexuid': 'label_yandex_uid'
        },
    }

    # Columns of the SQL result row copied verbatim into the order dict.
    RAW_ORDER_FIELDS = (
        'uid',
        'bus_internal_status',
        'billing_info',
        'supplier_order_id',
        'search_from_id',
        'search_to_id',
        'label_platform'
    )

    # Accepted --buses-environment values; anything else falls back to 'testing'.
    ALLOWED_ENVS = ('production', 'testing')

    def __init__(self, options):
        """Build the collector from parsed command-line ``options``.

        ``updated_from`` is widened to the start of its day and ``updated_to``
        to the end of its day before being converted with ``timestamp``.
        The buses environment and vault token are exported into the process
        environment (``YENV_TYPE`` / ``RASP_VAULT_OAUTH_TOKEN``), presumably
        for the ``travel.rasp.bus.db`` session used in ``_load_orders``.

        Raises:
            ValueError: if ``options.vault_token`` is None.
        """
        super(BusesCollector, self).__init__()

        self.updated_from = timestamp(datetime.combine(parse_datetime_iso(options.updated_from).date(), datetime.min.time()))
        self.updated_to = timestamp(datetime.combine(parse_datetime_iso(options.updated_to).date(), datetime.max.time()))

        self.buses_environment = options.buses_environment
        if self.buses_environment not in self.ALLOWED_ENVS:
            self.buses_environment = 'testing'

        environ['YENV_TYPE'] = self.buses_environment
        # NOTE(review): --vault-token is not registered in configure() below;
        # presumably it is added elsewhere (e.g. by the base collector) — confirm.
        if options.vault_token is None:
            raise ValueError('Token option "--vault-token" not found')
        environ['RASP_VAULT_OAUTH_TOKEN'] = options.vault_token
        self._load_orders_with_retries = with_retries(
            self._load_orders,
            counter=self.metrics,
            key='collector.events.invalid_response',
        )

        self.bus_rates = BusesRates()

    @classmethod
    def configure(cls, parser):
        """Register the bus-specific command-line arguments on ``parser``.

        Defaults: ``--updated-from`` is two days before today,
        ``--updated-to`` is today and ``--buses-environment`` is 'production'.
        """
        parser.add_argument('--updated-from', default=(date.today() - relativedelta(days=2)).isoformat())
        parser.add_argument('--updated-to', default=date.today().isoformat())
        parser.add_argument('--buses-environment', default='production')

    def _get_snapshots(self):
        """Yield one snapshot per order row loaded (with retries) from the DB."""
        raw_orders = self._load_orders_with_retries()
        for raw_order in raw_orders:
            yield self._make_snapshot(raw_order)

    def _load_orders(self):
        """Fetch raw order rows (as dicts) for the configured window and partner.

        An order is selected when it has at least one ``orders_log`` row whose
        ``timestamp`` lies between ``updated_from * 1000`` and
        ``updated_to * 1000`` (the multiplier suggests the log stores
        milliseconds — confirm). JSON fields extracted with ``->>`` arrive as
        text; ``ride``, ``booking``, ``billing_log`` and ``tracking`` are
        selected whole.
        """
        from travel.rasp.bus.db import session_scope
        with session_scope() as session:
            with session.connection().connection.cursor(cursor_factory=RealDictCursor) as cursor:
                cursor.execute('''SELECT public.orders.id as uid, public.orders.status as bus_internal_status, public.orders.ride, public.orders.booking,
                    public.orders.creation_ts AT TIME ZONE 'UTC' as creation_ts, public.orders.billing_log, public.orders.billing_log::TEXT as billing_info,
                    public.orders.ride->> 'onlineRefund' as online_refund,
                    public.orders.booking ->> 'supplierId' as supplier_order_id,
                    public.orders.ride ->> 'arrival' as arrival_dt,
                    public.orders.ride ->> 'departure' as departure_dt,
                    public.orders.ride ->> 'fromSearch' as search_from_id,
                    public.orders.ride ->> 'toSearch' as search_to_id,
                    public.orders.booking ->> 'source' as label_platform,
                    public.orders.tracking
                    FROM orders
                    WHERE id in (
                        SELECT id
                        FROM orders_log
                        WHERE timestamp BETWEEN %s::bigint * 1000 and %s::bigint * 1000
                    )
                    AND orders.booking->>'partner' like %s;''',
                               (self.updated_from, self.updated_to, self.PARTNER_NAME))
                return cursor.fetchall()

    def _make_snapshot(self, raw_order):
        """Convert one raw DB row into a ``SNAPSHOT_CLS`` snapshot."""
        order = self._get_order_from_raw(raw_order)

        snapshot = self.SNAPSHOT_CLS.from_dict(order, ignore_unknown=True)
        snapshot.update_partner_order_id(text_type(order['supplier_order_id']))

        return snapshot

    @staticmethod
    def _parse_json_bool(value):
        """Interpret a JSON boolean extracted with Postgres ``->>`` (i.e. as text).

        ``->>`` returns JSON ``true``/``false`` as the strings 'true'/'false'
        and JSON null as SQL NULL (None). ``bool()`` on those strings is
        always True, so parse the text instead.

        Returns:
            False for None, '', 'false', '0' or 'null' (case-insensitive);
            True for any other value.
        """
        if value is None:
            return False
        return text_type(value).strip().lower() not in ('', 'false', '0', 'null')

    def _get_order_from_raw(self, raw_order):
        """Build the snapshot field dict for a single raw order row.

        Raises:
            KeyError: if the bus status is not in ``STATUS_MAPPING`` or an
                expected key is missing from the row / ride JSON.
        """
        order = {}
        for field in self.RAW_ORDER_FIELDS:
            order[field] = raw_order[field]

        order['status'] = self.STATUS_MAPPING[raw_order['bus_internal_status']]
        # ``ride->>'onlineRefund'`` arrives as text: a JSON false is the
        # non-empty string 'false', which plain bool() would treat as True.
        order['online_refund'] = self._parse_json_bool(raw_order['online_refund'])
        self._calc_order_billing(raw_order, order)
        self._count_order_tickets(raw_order, order)
        self._extract_order_labels(raw_order, order)
        order['departure'] = timestamp(parse_datetime_iso(raw_order['departure_dt']))
        # Arrival time is optional; the field is left unset when absent.
        if raw_order['arrival_dt']:
            order['arrival'] = timestamp(parse_datetime_iso(raw_order['arrival_dt']))
        ride = raw_order['ride']
        order['from_id'] = ride['fromDetails'].get('raspId')
        order['to_id'] = ride['toDetails'].get('raspId')

        order['bus_model'] = ride.get('bus')
        order['route_name'] = ride.get('name')

        order['from_partner_description'] = ride['fromDescription']
        order['to_partner_description'] = ride['toDescription']
        order['carrier_id'] = ride.get('carrierID')

        order['currency_code'] = OrderCurrencyCode.RUB
        order['created_at'] = timestamp(raw_order['creation_ts'])
        order['updated_at'] = timestamp(self.now)

        return order

    def _calc_order_billing(self, raw_order, order):
        """Fill the monetary fields of ``order`` from ``raw_order['billing_log']``.

        All amounts start at 0.0 and ``finished_at`` at None. "payment" events
        set ``finished_at``/``order_amount`` and accumulate fee, tariff and
        agency-fee totals per record, classified by the record's
        ``developer_payload`` suffix. "refund" events add their amounts to the
        matching refund total, keyed by ``service_order_id``.
        """
        order.update({
            'finished_at': None,
            'order_amount': .0,
            'profit_amount': .0,
            'total_agency_fee_amount': .0,
            'total_fee_amount': .0,
            'total_partner_fee_amount': .0,
            'total_partner_refund_fee_amount': .0,
            'total_refund_fee_amount': .0,
            'total_refund_ticket_amount': .0,
            'total_tariff_amount': .0,
            'partner_commission': .0,
        })
        if not raw_order['billing_log']:
            return

        # service_order_id -> order field that receives refunds for that record.
        refund_service_order_ids = {}
        for billing_event in raw_order['billing_log']:
            if billing_event['type'] == "payment":
                order['finished_at'] = timestamp(parse_datetime_iso(billing_event['timestamp']).astimezone(UTC))
                order['order_amount'] = float(billing_event['amount'])

                for order_record in billing_event['orders']:
                    record_amount = float(order_record['amount'])
                    # '-yandex-fee' must be tested before '-fee': it also ends with '-fee'.
                    if order_record['developer_payload'].endswith('-yandex-fee'):
                        order['total_fee_amount'] += record_amount
                        refund_service_order_ids[order_record['service_order_id']] = 'total_refund_fee_amount'
                        continue
                    if order_record['developer_payload'].endswith('-fee'):
                        order['total_partner_fee_amount'] += record_amount
                        refund_service_order_ids[order_record['service_order_id']] = 'total_partner_refund_fee_amount'
                        # Agency fee per record is rounded to cents with a 0.01 floor.
                        order['total_agency_fee_amount'] +=\
                            max(0.01, round(self._get_order_agency_rate(raw_order) * record_amount, 2))
                        continue
                    if order_record['developer_payload'].endswith('-ticket'):
                        order['total_tariff_amount'] += record_amount
                        refund_service_order_ids[order_record['service_order_id']] = 'total_refund_ticket_amount'
                        order['total_agency_fee_amount'] +=\
                            max(0.01, round(self._get_order_agency_rate(raw_order) * record_amount, 2))

                order['profit_amount'] = order['total_fee_amount'] + order['total_agency_fee_amount']

            if billing_event['type'] == "refund":
                refund_records = billing_event.get('orders')
                if not refund_records:
                    continue
                # Only refunds for records seen in an earlier payment event are counted.
                for refund_record in refund_records:
                    if refund_record['service_order_id'] in refund_service_order_ids:
                        order[refund_service_order_ids[refund_record['service_order_id']]] += float(
                            refund_record['amount'])

    def _count_order_tickets(self, raw_order, order):
        """Fill ticket counters on ``order`` from ``raw_order['booking']['tickets']``."""
        order.update({
            'adult_passengers_count': 0,
            'children_with_seats_count': 0,
            'baggage_tickets_count': 0,
            'active_ticket_count': 0,
            'refunded_ticket_count': 0,
            'total_ticket_count': len(raw_order['booking']['tickets']),
            'requested_ticket_count': len(raw_order['booking']['tickets'])
        })

        for ticket in raw_order['booking']['tickets']:
            if ticket['status']['name'] == 'returned':
                order['refunded_ticket_count'] += 1
            if ticket['status']['name'] == 'sold':
                order['active_ticket_count'] += 1

            # ticketType has been observed as null; such tickets are not typed.
            if ticket['passenger']['ticketType']:
                order[self.TICKET_TYPE_MAPPER[ticket['passenger']['ticketType']['id']]] += 1

    def _extract_order_labels(self, raw_order, order):
        """Copy tracking labels from ``raw_order['tracking']`` per ``LABELS_MAPPER``.

        For nested mappings, string values are copied as-is, non-empty lists
        contribute their first element, and any other value type is skipped.
        """
        if not raw_order['tracking']:
            return

        for field_key, label in self.LABELS_MAPPER.items():
            upper_label_value = raw_order['tracking'].get(field_key)
            if not upper_label_value:
                continue
            if isinstance(label, dict):
                for nested_field_key, nested_label in label.items():
                    label_value = upper_label_value.get(nested_field_key)
                    if label_value:
                        if isinstance(label_value, text_type):
                            order[nested_label] = label_value
                        elif isinstance(label_value, list) and len(label_value) > 0:
                            order[nested_label] = label_value[0]
            else:
                order[label] = upper_label_value

    def _get_order_agency_rate(self, raw_order):
        """Return the agency rate for ``PARTNER_NAME`` on the order's creation date."""
        order_date = raw_order['creation_ts'].date()
        return self.bus_rates.get_order_agency_rate(self.PARTNER_NAME, order_date)
