# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import logging
import json
from datetime import timedelta, datetime, time
from typing import Dict, Optional, TextIO, Tuple

from django.conf import settings
from yql.api.v1.client import YqlClient

from common.data_api.movista.instance import movista_client
from common.data_api.im.instance import im_client
from travel.rasp.library.python.common23.date import environment
from common.settings.utils import define_setting
from common.settings.configuration import Configuration

log = logging.getLogger(__name__)


# YT directory with CPA suburban orders; it is joined with billing transactions in get_billing_orders().
# NOTE(review): unlike YT_BILLING_TRANSACTIONS_DIR this path has no leading '//'; it is used as
# hahn.`<path>` in the YQL query — confirm both path forms resolve to the intended tables.
define_setting(
    'YT_CPA_SUBURBAN_ORDERS_DIR',
    {Configuration.PRODUCTION: 'home/travel/prod/cpa/suburban/orders'},
    default='home/travel/testing/cpa/suburban/orders',
)
# YT directory with billing transactions; read through RANGE() over the tables for the compared dates.
define_setting(
    'YT_BILLING_TRANSACTIONS_DIR',
    {Configuration.PRODUCTION: '//home/travel/prod/billing/transactions'},
    default='//home/travel/testing/billing/transactions',
)

# Billing service_id used to select suburban transactions in the YQL query.
define_setting('BILLING_SUBURBAN_SERVICE_ID', default=716)
# Directory where run() writes its 'result' report file.
define_setting('SUBURBAN_BILLING_COMPARE_RESULT_DATA_PATH', default='suburban_billing_compare')

# Tolerance (minutes) between provider and billing timestamps: the Movista query starts this much
# earlier and the billing query ends this much later than the compared period.
define_setting('MINUTES_BETWEEN_BILLING_AND_PROVIDER', default=10)


class ProviderTicket(object):
    """Provider-agnostic view of one ticket (order) reported by a ticket provider.

    Provider-specific subclasses map their API payload onto these four attributes.
    """

    def __init__(self, provider_order_id, status, update_dt, price):
        # type: (int, str, str, float) -> None
        # Values are stored as given; no conversion or validation is done here.
        self.provider_order_id, self.status, self.update_dt, self.price = (
            provider_order_id, status, update_dt, price
        )


class MovistaTicket(ProviderTicket):
    """Ticket built from one element of ``tickets`` in the Movista ``report`` response."""

    def __init__(self, ticket):
        order_id = ticket['orderId']
        status = ticket['status']
        # Confirmed tickets are timestamped by confirmDate, any other status by refundDate.
        timestamp = ticket['confirmDate'] if status == 'confirmed' else ticket['refundDate']
        super(MovistaTicket, self).__init__(
            provider_order_id=order_id,
            status=status,
            update_dt=timestamp,
            price=ticket['price'],
        )

    def to_file_str(self):
        """Serialize the ticket as a one-line JSON object for the report file."""
        fields = {
            'provider_order_id': self.provider_order_id,
            'status': self.status,
            'update_dt': self.update_dt,
            'price': self.price,
        }
        return json.dumps(fields)


class ImTicket(ProviderTicket):
    """Ticket built from one order of the IM ``order_list`` response.

    Only the first element of ``OrderItems`` is consulted for the operation status.
    """

    def __init__(self, order):
        first_item = order['OrderItems'][0]
        super(ImTicket, self).__init__(
            order['OrderId'],
            first_item['SimpleOperationStatus'],
            order['Confirmed'],
            order['Amount'],
        )

    def to_file_str(self):
        """Serialize the ticket as a one-line JSON object for the report file.

        The timestamp (IM's ``Confirmed`` field) is written under the ``confirm_dt`` key.
        """
        payload = {}
        payload['provider_order_id'] = self.provider_order_id
        payload['status'] = self.status
        payload['confirm_dt'] = self.update_dt
        payload['price'] = self.price
        return json.dumps(payload)


class BillingOrder(object):
    """One billing transaction row, joined with its provider order id by the YQL query."""

    def __init__(self, order_row):
        # Row columns, in query order: provider_order_id, service_order_id, transaction_type, price, update_dt.
        (raw_provider_id, self.service_order_id, self.transaction_type,
         self.price, self.update_dt) = order_row
        # Empty/missing provider ids (no matching CPA record) become None; others are converted to int.
        self.provider_order_id = None if not raw_provider_id else int(raw_provider_id)

    def to_file_str(self):
        """Serialize the order as a one-line JSON object for the report file."""
        payload = {}
        payload['provider_order_id'] = self.provider_order_id
        payload['service_order_id'] = self.service_order_id
        payload['transaction_type'] = self.transaction_type
        payload['update_dt'] = self.update_dt
        payload['price'] = self.price
        return json.dumps(payload)


class DifferentPriceOrder(object):
    """An order found on both sides whose billing price differs from the provider price."""

    def __init__(self, order_id, billing_price, provider_price, provider_update_dt):
        self.order_id, self.billing_price = order_id, billing_price
        self.provider_price, self.provider_update_dt = provider_price, provider_update_dt

    def to_file_str(self):
        """Serialize the mismatch as a one-line JSON object for the report file."""
        pairs = (
            ('order_id', self.order_id),
            ('provider_update_dt', self.provider_update_dt),
            ('provider_price', self.provider_price),
            ('billing_price', self.billing_price),
        )
        return json.dumps({key: value for key, value in pairs})


def get_movista_tickets(start_period_dt, end_period_dt):
    # type: (datetime, datetime) -> Dict[int, MovistaTicket]
    """Fetch confirmed Movista tickets and index them by Movista order id.

    The requested window starts MINUTES_BETWEEN_BILLING_AND_PROVIDER minutes before ``start_period_dt``
    and ends at ``end_period_dt``.
    """
    lag = timedelta(minutes=settings.MINUTES_BETWEEN_BILLING_AND_PROVIDER)
    report = movista_client.report(start_period_dt - lag, end_period_dt, status='confirmed')

    tickets_by_id = {}
    for raw_ticket in report['tickets']:
        tickets_by_id[raw_ticket['orderId']] = MovistaTicket(raw_ticket)
    return tickets_by_id


def get_im_tickets(start_period_dt, end_period_dt):
    # type: (datetime, datetime) -> Dict[int, ImTicket]
    """Fetch successful IM orders day by day and index them by IM order id.

    Days are queried starting one day before ``start_period_dt``, stepping by one day while the day is
    earlier than ``end_period_dt``. An order is kept only if its first order item has
    ``SimpleOperationStatus == 'Succeeded'``.
    NOTE(review): the extra leading day presumably catches orders listed on the previous day but
    confirmed inside the period — confirm.
    """
    day = start_period_dt - timedelta(days=1)
    tickets_by_id = {}
    while day < end_period_dt:
        orders = im_client.order_list(day.date())['Orders']
        for order in orders:
            if order['OrderItems'][0]['SimpleOperationStatus'] != 'Succeeded':
                continue
            tickets_by_id[order['OrderId']] = ImTicket(order)
        day += timedelta(days=1)
    return tickets_by_id


def get_billing_orders(provider, start_period_dt, end_period_dt):
    # type: (str, datetime, datetime) -> Dict[int, BillingOrder]
    """Load billing payment transactions for ``provider`` and index them by provider order id.

    Billing transactions from YT_BILLING_TRANSACTIONS_DIR are joined with CPA suburban orders from
    YT_CPA_SUBURBAN_ORDERS_DIR: the first 17 characters of the CPA ``partner_order_id`` must equal the
    billing ``service_order_id``. Only ``payment``/``cost`` transactions of BILLING_SUBURBAN_SERVICE_ID
    whose ``update_dt`` is earlier than ``end_period_dt`` + MINUTES_BETWEEN_BILLING_AND_PROVIDER
    minutes are returned.

    :param provider: provider name; upper-cased to match the CPA ``provider`` column.
    :param start_period_dt: start of the compared period; its date is the first billing table in RANGE().
    :param end_period_dt: end of the compared period; its date is the last billing table in RANGE().
    :return: mapping of provider order id -> BillingOrder. Rows without a provider order id are dropped;
        if several rows share an id, the last one returned by YQL wins.
    """
    # In CPA the departure date is stored as UTC time, and it can even fall on the previous day,
    # so the condition looks two days back.
    departure_date_bound = start_period_dt - timedelta(days=2)
    # Presumably billing can lag behind the provider, so transactions slightly after the period end are accepted.
    end_dt = end_period_dt + timedelta(minutes=settings.MINUTES_BETWEEN_BILLING_AND_PROVIDER)

    # NOTE(review): values are interpolated into the query text without escaping; `provider` comes from
    # the command line, so keep it restricted to known provider names.
    query = '''
        $provider_id = (
            SELECT substring(partner_order_id, 0, 17) as inner_id, provider_order_id
            FROM hahn.`{cpa_dir}`
            WHERE departure_date >= '{departure_date_bound}'
                  and provider = '{provider}'
                  and provider_order_id is not null
        );

        SELECT provider.provider_order_id, service_order_id, transaction_type, price, update_dt
        FROM RANGE(`{billing_dir}`, '{start_period}', '{end_period}') as log
        JOIN $provider_id as provider ON log.service_order_id = provider.inner_id
        WHERE service_id = {service_id}
              and transaction_type = 'payment'
              and payment_type = 'cost'
              and update_dt < '{end_dt}'
    '''.format(
        cpa_dir=settings.YT_CPA_SUBURBAN_ORDERS_DIR,
        billing_dir=settings.YT_BILLING_TRANSACTIONS_DIR,
        departure_date_bound=departure_date_bound.strftime('%Y-%m-%d'),
        provider=provider.upper(),
        start_period=start_period_dt.strftime('%Y-%m-%d'),
        end_period=end_period_dt.strftime('%Y-%m-%d'),
        end_dt=end_dt.strftime('%Y-%m-%dT%H:%M:%S'),
        service_id=settings.BILLING_SUBURBAN_SERVICE_ID
    )

    log.info('Run YQL query: %s', query)

    billing_orders_by_ids = {}
    with YqlClient(db='hahn', token=settings.YQL_TOKEN) as yql_client:
        request = yql_client.query(query)
        request.run()

        for table in request:
            for row in table.get_iterator():
                order = BillingOrder(row)
                # Rows without a usable provider order id cannot be matched with provider tickets.
                if order.provider_order_id:
                    # Reuse the already-built order instead of parsing the row a second time.
                    billing_orders_by_ids[order.provider_order_id] = order

    return billing_orders_by_ids


class OrdersCheckResult(object):
    """Outcome of comparing provider tickets with billing orders for one period.

    ``start_date`` is the first day of the period. ``end_date`` is one day before ``end_period_dt``'s
    date, because the end bound is treated as exclusive.
    """

    def __init__(
        self, provider, start_period_dt, end_period_dt, provider_tickets_count, provider_total_price,
        missed_provider_tickets, billing_orders_count, billing_total_price, missed_billing_orders,
        different_price_orders
    ):
        self.provider = provider
        self.start_date = start_period_dt.date()
        self.end_date = end_period_dt.date() - timedelta(days=1)

        # Provider side: tickets matched in billing, their total price, and tickets absent from billing.
        self.provider_tickets_count = provider_tickets_count
        self.provider_total_price = provider_total_price
        self.missed_provider_tickets = missed_provider_tickets

        # Billing side: orders matched at the provider, their total price, and orders absent at the provider.
        self.billing_orders_count = billing_orders_count
        self.billing_total_price = billing_total_price
        self.missed_billing_orders = missed_billing_orders

        # Orders present on both sides with different prices.
        self.different_price_orders = different_price_orders

    def save_to_file(self, result_file):
        # type: (TextIO) -> None
        """Write a plain-text report: header, provider section, billing section, price differences,
        and a closing note when no differences were found."""

        def emit(line=''):
            result_file.write('{}\n'.format(line))

        emit('{} orders for {} - {}'.format(self.provider, self.start_date, self.end_date))

        emit()
        emit('Total price of {} tickets is {}'.format(self.provider, self.provider_total_price))
        emit('{} {} tickets were found in billing'.format(self.provider_tickets_count, self.provider))
        provider_missed = len(self.missed_provider_tickets)
        if provider_missed:
            emit('{} {} tickets were missed in billing:'.format(provider_missed, self.provider))
            for missed_ticket in self.missed_provider_tickets:
                emit(missed_ticket.to_file_str())

        emit()
        emit('Total price of billing tickets is {}'.format(self.billing_total_price))
        emit('{} Billing orders were found in {}'.format(self.billing_orders_count, self.provider))
        billing_missed = len(self.missed_billing_orders)
        if billing_missed:
            emit('{} Billing orders were missed in {}:'.format(billing_missed, self.provider))
            for missed_order in self.missed_billing_orders:
                emit(missed_order.to_file_str())

        mismatches = len(self.different_price_orders)
        if mismatches:
            emit()
            emit('{} Different price orders were found'.format(mismatches))
            for mismatch in self.different_price_orders:
                emit(mismatch.to_file_str())

        if not self.has_errors:
            emit()
            emit('Differences were not found')

    @property
    def has_errors(self):
        # type: () -> bool
        """True when any ticket/order is missing on the other side or has a different price."""
        return any(len(group) > 0 for group in (
            self.missed_provider_tickets,
            self.missed_billing_orders,
            self.different_price_orders,
        ))


def compare_orders(provider, start_period_dt, end_period_dt, provider_tickets_by_ids, billing_orders_by_ids):
    # type: (str, datetime, datetime, Dict[int, ProviderTicket], Dict[int, BillingOrder]) -> OrdersCheckResult
    """Cross-check provider tickets against billing orders for one period.

    Provider tickets with ``update_dt`` not earlier than the period start are summed and matched
    against billing. Tickets missing from billing are reported. Tickets present in billing whose
    price differs (compared as floats) become DifferentPriceOrder entries.

    Billing orders with ``update_dt`` not later than the period end are summed (as floats) and
    matched against all provider tickets. Orders missing at the provider are reported.

    Timestamps are compared as strings against ``%Y-%m-%dT%H:%M:%S``-formatted bounds.
    NOTE(review): this assumes both sides use that same string format — confirm.

    :return: OrdersCheckResult whose lists are sorted by their update timestamp.
    """
    stamp_format = '%Y-%m-%dT%H:%M:%S'
    period_start = start_period_dt.strftime(stamp_format)
    period_end = end_period_dt.strftime(stamp_format)

    # Provider -> billing direction.
    tickets_matched = 0
    tickets_price_sum = 0
    tickets_not_in_billing = []
    price_mismatches = []
    for order_id, ticket in provider_tickets_by_ids.items():
        if ticket.update_dt < period_start:
            # Skipped here, but still used as matches for billing orders below.
            continue
        tickets_price_sum += ticket.price
        if order_id not in billing_orders_by_ids:
            tickets_not_in_billing.append(ticket)
            continue
        tickets_matched += 1
        billing_price = billing_orders_by_ids[order_id].price
        if float(billing_price) != float(ticket.price):
            price_mismatches.append(
                DifferentPriceOrder(order_id, billing_price, ticket.price, ticket.update_dt)
            )

    # Billing -> provider direction.
    orders_matched = 0
    orders_price_sum = 0
    orders_not_in_provider = []
    for order_id, order in billing_orders_by_ids.items():
        if order.update_dt > period_end:
            continue
        orders_price_sum += float(order.price)
        if order_id in provider_tickets_by_ids:
            orders_matched += 1
        else:
            orders_not_in_provider.append(order)

    return OrdersCheckResult(
        provider, start_period_dt, end_period_dt,
        tickets_matched, tickets_price_sum,
        sorted(tickets_not_in_billing, key=lambda t: t.update_dt),
        orders_matched, orders_price_sum,
        sorted(orders_not_in_provider, key=lambda o: o.update_dt),
        sorted(price_mismatches, key=lambda m: m.provider_update_dt),
    )


class ComparePeriod(object):
    """Names of the comparison periods understood by get_period_dt."""

    # Explicit first/last dates supplied by the caller.
    CUSTOM_PERIOD = 'custom_period'
    # Yesterday only.
    PREVIOUS_DAY = 'previous_day'
    # From the first day of yesterday's month up to (not including) today.
    PREVIOUS_DAY_MONTH = 'previous_day_month'
    # The calendar month before the month that contains yesterday.
    PREVIOUS_MONTH = 'previous_month'


def get_period_dt(compare_period, start_date_str=None, end_date_str=None):
    # type: (str, Optional[str], Optional[str]) -> Tuple[datetime, datetime]
    """Translate a ComparePeriod value into the ``[start, end)`` datetime range to compare.

    Both bounds are midnight datetimes; ``end`` is the first day NOT included in the period.

    :param compare_period: one of the ComparePeriod constants.
    :param start_date_str: first day of the period as ``YYYY-MM-DD``; required for CUSTOM_PERIOD.
    :param end_date_str: last day of the period (inclusive) as ``YYYY-MM-DD``; required for CUSTOM_PERIOD.
    :raises ValueError: for an unknown ``compare_period``, or missing/malformed custom dates.
    """
    if compare_period == ComparePeriod.CUSTOM_PERIOD:
        if not start_date_str or not end_date_str:
            raise ValueError(
                'start and end dates are required for {}'.format(ComparePeriod.CUSTOM_PERIOD)
            )
        start_date = datetime.strptime(start_date_str, '%Y-%m-%d').date()
        # The user-supplied end date is inclusive; turn it into an exclusive bound.
        end_date = datetime.strptime(end_date_str, '%Y-%m-%d').date() + timedelta(days=1)

    else:
        today = environment.today()

        if compare_period == ComparePeriod.PREVIOUS_DAY:
            start_date = today - timedelta(days=1)
            end_date = today

        elif compare_period == ComparePeriod.PREVIOUS_DAY_MONTH:
            # Month-to-date up to and including yesterday.
            start_date = (today - timedelta(days=1)).replace(day=1)
            end_date = today

        elif compare_period == ComparePeriod.PREVIOUS_MONTH:
            # NOTE(review): based on yesterday, so when run on the 1st of a month this selects the
            # month before last — confirm this is intended.
            end_date = (today - timedelta(days=1)).replace(day=1)

            if end_date.month != 1:
                start_date = end_date.replace(month=end_date.month - 1)
            else:
                start_date = end_date.replace(year=end_date.year - 1, month=12)

        else:
            # Previously this fell through and failed with UnboundLocalError on start_date.
            raise ValueError('Unknown compare period: {!r}'.format(compare_period))

    return datetime.combine(start_date, time(0)), datetime.combine(end_date, time(0))


def run(provider, compare_period, start_date=None, end_date=None):
    # type: (str, str, Optional[str], Optional[str]) -> bool
    """Compare provider tickets with billing orders for a period and write a text report.

    The report is written to ``<SUBURBAN_BILLING_COMPARE_RESULT_DATA_PATH>/result``. The directory
    is created if it does not exist.

    :param provider: ``'movista'`` selects Movista; any other value is treated as IM.
    :param compare_period: one of the ComparePeriod constants.
    :param start_date: ``YYYY-MM-DD`` first day, used with ComparePeriod.CUSTOM_PERIOD.
    :param end_date: ``YYYY-MM-DD`` last day (inclusive), used with ComparePeriod.CUSTOM_PERIOD.
    :return: True if any differences were found.
    """
    start_period_dt, end_period_dt = get_period_dt(compare_period, start_date, end_date)

    log.info('Get tickets from %s', provider)
    if provider == 'movista':
        provider_tickets_by_ids = get_movista_tickets(start_period_dt, end_period_dt)
    else:
        provider_tickets_by_ids = get_im_tickets(start_period_dt, end_period_dt)
    log.info('%s %s tickets found', len(provider_tickets_by_ids), provider)

    log.info('Get orders from CPA and billing')
    billing_orders_by_ids = get_billing_orders(provider, start_period_dt, end_period_dt)
    log.info('%s Billing orders found', len(billing_orders_by_ids))

    log.info('Compare %s and billing data', provider)
    check_result = compare_orders(
        provider, start_period_dt, end_period_dt, provider_tickets_by_ids, billing_orders_by_ids
    )

    result_dir = settings.SUBURBAN_BILLING_COMPARE_RESULT_DATA_PATH
    result_path = os.path.join(result_dir, 'result')
    log.info('Saving results to file %s', result_path)
    if not os.path.exists(result_dir):
        # makedirs (not mkdir) so that a nested result path also works.
        os.makedirs(result_dir)
    with open(result_path, 'w') as result_file:
        check_result.save_to_file(result_file)

    # Log counts, not the lists of ticket/order objects themselves.
    log.info(
        'Run done. %s missed %s tickets, %s missed billing orders',
        len(check_result.missed_provider_tickets), provider, len(check_result.missed_billing_orders)
    )
    log.info('has_errors=%s', check_result.has_errors)

    return check_result.has_errors


def main():
    """Command-line entry point: parse the arguments and run the comparison."""
    import argparse
    parser = argparse.ArgumentParser(description='Compare suburban provider tickets with billing orders')
    parser.add_argument(
        '-p', '--provider', dest='provider', required=True,
        help="ticket provider: 'movista'; any other value is treated as IM",
    )
    parser.add_argument(
        '-c', '--compare-period', dest='compare_period', required=True,
        choices=[
            ComparePeriod.CUSTOM_PERIOD,
            ComparePeriod.PREVIOUS_DAY,
            ComparePeriod.PREVIOUS_DAY_MONTH,
            ComparePeriod.PREVIOUS_MONTH,
        ],
    )
    parser.add_argument('-s', '--start-date', dest='start_date', help='YYYY-MM-DD, first day (custom_period)')
    parser.add_argument('-e', '--end-date', dest='end_date', help='YYYY-MM-DD, last day inclusive (custom_period)')
    args = parser.parse_args()

    # Fail fast with a usage error instead of crashing later in get_period_dt.
    if args.compare_period == ComparePeriod.CUSTOM_PERIOD and not (args.start_date and args.end_date):
        parser.error('--start-date and --end-date are required for --compare-period custom_period')

    # NOTE(review): run() returns has_errors, but it is not reflected in the process exit status.
    run(args.provider, args.compare_period, args.start_date, args.end_date)


# Run the comparison when the module is executed as a script.
if __name__ == '__main__':
    main()
