from __future__ import print_function

import argparse
import collections
import logging
import socket
import sys
import uuid
from datetime import date, datetime, timedelta

import google.protobuf.timestamp_pb2 as google_pb2
import grpc
import travel.orders.proto.services.orders.orders_pb2 as orders_pb2
import travel.orders.proto.services.orders.orders_pb2_grpc as orders_pb2_grpc
from dateutil.relativedelta import relativedelta
from parse import parse
from travel.orders.tools.library.tvm import get_tvm_service_ticket


class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
                                grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
    """Adapt one plain function into all four gRPC client interceptor kinds.

    The wrapped function is invoked as
    ``fn(client_call_details, request_iterator, request_streaming, response_streaming)``
    and must return ``(new_details, new_request_iterator, postprocess)``.
    ``postprocess`` is either falsy or a callable applied to whatever the
    continuation returns. Unary requests are wrapped in a one-element iterator
    for the function, and the first element of the returned iterator is passed
    on to the continuation.
    """

    def __init__(self, interceptor_function):
        self._fn = interceptor_function

    def _dispatch(self, continuation, client_call_details, request_iterator,
                  request_streaming, response_streaming):
        """Run the wrapped function, call the continuation, then apply postprocess."""
        details, requests, postprocess = self._fn(
            client_call_details, request_iterator, request_streaming, response_streaming)
        # Unary continuations expect a single message; streaming ones the iterator.
        outgoing = requests if request_streaming else next(requests)
        outcome = continuation(details, outgoing)
        if not postprocess:
            return outcome
        return postprocess(outcome)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._dispatch(continuation, client_call_details, iter((request,)), False, False)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._dispatch(continuation, client_call_details, iter((request,)), False, True)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return self._dispatch(continuation, client_call_details, request_iterator, True, False)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return self._dispatch(continuation, client_call_details, request_iterator, True, True)


class _ClientCallDetails(collections.namedtuple('_ClientCallDetails', 'method timeout metadata credentials'),
                         grpc.ClientCallDetails):
    """Immutable ``grpc.ClientCallDetails`` used to hand rewritten call details to gRPC.

    Fields, in order: ``method``, ``timeout``, ``metadata``, ``credentials``.
    """


def create_grpc_channel(args):
    """Open an insecure gRPC channel to ``args.host:args.port`` with an auth/logging interceptor.

    Every call made through the returned channel:

    * gets ``ya-grpc-call-id``, ``ya-grpc-started-at``, ``ya-grpc-fqdn`` and
      ``x-ya-service-ticket`` metadata (caller metadata is kept, but these keys
      overwrite caller-supplied values with the same name);
    * uses ``args.timeout`` as its deadline unless the caller passed one;
    * is logged at INFO (method, call id, status code) and at DEBUG (request,
      response metadata, and the response or error details).

    Only unary-unary calls are supported by the interceptor.

    Args:
        args: parsed command-line namespace; reads ``host``, ``port``,
            ``timeout``, ``tvm_service_ticket``, ``tvm_client_id``,
            ``tvm_service_id``, ``tvm_client_secret`` and ``skip_authentication``.

    Returns:
        The intercepted ``grpc.Channel``.
    """
    # The ticket is resolved once per channel and reused for every call.
    ticket = args.tvm_service_ticket or get_tvm_service_ticket(
        tvm_client_id=args.tvm_client_id,
        tvm_service_id=args.tvm_service_id,
        tvm_client_secret=args.tvm_client_secret,
        skip_authentication=args.skip_authentication
    )

    def ya_grpc_interceptor(client_call_details, request_iterator,
                            request_streaming, response_streaming):
        """Decorate one call and return ``(details, request_iterator, postprocess)``."""
        # The request is consumed eagerly below, so streaming calls cannot be
        # handled. Raise explicitly rather than ``assert`` so the guard is not
        # stripped under ``python -O``.
        if request_streaming or response_streaming:
            raise NotImplementedError(
                "ya_grpc_interceptor supports only unary-unary calls")

        ya_call_id = str(uuid.uuid1())
        ya_started_at = datetime.utcnow().isoformat() + "Z"
        ya_fqdn = socket.getfqdn()

        metadata = {}
        if client_call_details.metadata is not None:
            metadata = dict(client_call_details.metadata)
        metadata["ya-grpc-call-id"] = ya_call_id
        metadata["ya-grpc-started-at"] = ya_started_at
        metadata["ya-grpc-fqdn"] = ya_fqdn
        metadata["x-ya-service-ticket"] = ticket

        timeout = args.timeout
        if client_call_details.timeout is not None:
            timeout = client_call_details.timeout

        client_call_details = _ClientCallDetails(
            client_call_details.method,
            timeout,
            list(metadata.items()),
            client_call_details.credentials)

        request = next(request_iterator)
        logging.info("<- REQ %s (CallId: %s)", client_call_details.method, ya_call_id)
        logging.debug("%s\n%r", request.__class__, request)

        def complete(rendezvous):
            """Log the outcome of a finished call without raising."""
            code = rendezvous.code()
            logging.info("-> RSP %s (CallId: %s, Code: %s, Cancelled: %s)",
                         client_call_details.method, ya_call_id,
                         code, rendezvous.cancelled())

            # Metadata can be None when the call failed before any metadata was
            # received; iterating None would raise TypeError and hide the real
            # RPC error. Items are (key, value) pairs.
            response_metadata = {}
            for key, value in (rendezvous.initial_metadata() or ()):
                response_metadata[key] = value
            for key, value in (rendezvous.trailing_metadata() or ()):
                response_metadata[key] = value
            for key, value in sorted(response_metadata.items()):
                logging.debug("Meta: [%s] => %r", key, value)

            if rendezvous.cancelled():
                return
            if code == grpc.StatusCode.OK:
                response = rendezvous.result()
                logging.debug("%s\n%r", response.__class__, response)
            else:
                # result() would raise here; the stub invocation raises the
                # error to the caller anyway, so only log the details.
                logging.debug("Error details: %s", rendezvous.details())

        def postprocess(rendezvous):
            """Arrange for ``complete`` to run once the call terminates."""
            def on_done():
                complete(rendezvous)
            # add_callback returns False when the RPC has already terminated
            # and the callback will never fire, so run it inline in that case.
            if not rendezvous.add_callback(on_done):
                on_done()
            return rendezvous

        return client_call_details, iter((request,)), postprocess

    endpoint = "%s:%s" % (args.host, args.port)
    logging.info("Creating a channel to '%s'", endpoint)
    channel = grpc.insecure_channel(endpoint)
    return grpc.intercept_channel(channel, _GenericClientInterceptor(ya_grpc_interceptor))


def handle_createPartnerPayoutReport(args):
    """Request a partner payout report and save the returned file.

    Builds ``TDevCreatePartnerPayoutReportReq`` from the billing partner and
    contract ids, the external contract id, and the ``period_start`` /
    ``period_end`` dates. Both dates are truncated to calendar days and encoded
    via ``seconds_from_date``. Calls ``CreatePartnerPayoutReport`` and writes the
    response's ``Bytes`` to ``args.output_file``.
    """
    period_start_date = parse_datetime_iso(args.period_start).date()
    # Truncated like period_start (and like planPartnerReportsSending);
    # seconds_from_date ignores the time part anyway, this makes it explicit.
    period_end_date = parse_datetime_iso(args.period_end).date()

    req = orders_pb2.TDevCreatePartnerPayoutReportReq(
        BillingPartnerId=int(args.billing_client_id),
        BillingContractId=int(args.billing_contract_id),
        ExternalContractId=args.external_contract_id,
        PeriodStart=google_pb2.Timestamp(
            seconds=seconds_from_date(period_start_date)
        ),
        PeriodEnd=google_pb2.Timestamp(
            seconds=seconds_from_date(period_end_date)
        )
    )
    with create_grpc_channel(args) as channel:
        stub = orders_pb2_grpc.ReportDevInterfaceDoNotUseStub(channel)
        rsp = stub.CreatePartnerPayoutReport(req)

    # The response is fully in memory; the channel need not stay open for disk I/O.
    logging.info("Got response. Saving it to file %s", args.output_file)
    with open(args.output_file, "wb") as output:
        output.write(rsp.Bytes)

def handle_createPartnerOrdersReport(args):
    """Request a partner orders report and save the returned file.

    Builds ``TDevCreatePartnerOrdersReportReq`` from the billing partner id and
    the ``period_start`` / ``period_end`` dates. Both dates are truncated to
    calendar days and encoded via ``seconds_from_date``. Calls
    ``CreatePartnerOrdersReport`` and writes the response's ``Bytes`` to
    ``args.output_file``.
    """
    period_start_date = parse_datetime_iso(args.period_start).date()
    # Truncated like period_start (and like planPartnerReportsSending);
    # seconds_from_date ignores the time part anyway, this makes it explicit.
    period_end_date = parse_datetime_iso(args.period_end).date()

    req = orders_pb2.TDevCreatePartnerOrdersReportReq(
        BillingPartnerId=int(args.billing_client_id),
        PeriodStart=google_pb2.Timestamp(
            seconds=seconds_from_date(period_start_date)
        ),
        PeriodEnd=google_pb2.Timestamp(
            seconds=seconds_from_date(period_end_date)
        )
    )
    with create_grpc_channel(args) as channel:
        stub = orders_pb2_grpc.ReportDevInterfaceDoNotUseStub(channel)
        rsp = stub.CreatePartnerOrdersReport(req)

    # The response is fully in memory; the channel need not stay open for disk I/O.
    logging.info("Got response. Saving it to file %s", args.output_file)
    with open(args.output_file, "wb") as output:
        output.write(rsp.Bytes)


def handle_createPartnerPaymentOrderReport(args):
    """Fetch a payment-order report and write its raw bytes to ``args.output_file``.

    The request carries ``args.payment_order_id`` and ``args.payment_batch_id``
    unchanged; the response's ``Bytes`` field is written in binary mode.
    """
    request_fields = dict(
        PaymentOrderId=args.payment_order_id,
        PaymentBatchId=args.payment_batch_id,
    )
    request = orders_pb2.TDevCreatePartnerPaymentOrderReportReq(**request_fields)
    target_path = args.output_file

    with create_grpc_channel(args) as grpc_channel:
        report_api = orders_pb2_grpc.ReportDevInterfaceDoNotUseStub(grpc_channel)
        reply = report_api.CreatePartnerPaymentOrderReport(request)
        logging.info("Got response. Saving it to file {}".format(target_path))

        with open(target_path, "wb") as sink:
            sink.write(reply.Bytes)


def handle_sendPartnerPaymentOrderReport(args):
    """Ask the service to e-mail a payment-order report and log the operation id.

    ``args.payment_order_id``, ``args.payment_batch_id`` and ``args.email`` are
    passed through to ``TDevSendPartnerPaymentOrderReportReq`` as-is.
    """
    request = orders_pb2.TDevSendPartnerPaymentOrderReportReq(
        PaymentOrderId=args.payment_order_id,
        PaymentBatchId=args.payment_batch_id,
        Email=args.email,
    )
    with create_grpc_channel(args) as grpc_channel:
        report_api = orders_pb2_grpc.ReportDevInterfaceDoNotUseStub(grpc_channel)
        operation_id = report_api.SendPartnerPaymentOrderReport(request).OperationId
        logging.info("Report sending scheduled. Operation id {}".format(operation_id))


def handle_sendPartnerReports(args):
    """Ask the service to e-mail partner reports for a period and log the operation id.

    The period bounds are truncated to calendar days and encoded via
    ``seconds_from_date``. ``ReportAt`` is set to the current moment.
    ``args.email`` and ``args.report_type`` are passed through unchanged.
    """
    period_start_date = parse_datetime_iso(args.period_start).date()
    # Truncated like period_start; seconds_from_date ignores the time anyway.
    period_end_date = parse_datetime_iso(args.period_end).date()

    # google.protobuf.Timestamp counts seconds since the Unix epoch in UTC.
    # datetime.now() is local wall-clock time and would shift ReportAt by the
    # host's UTC offset, so take "now" in UTC before subtracting the epoch.
    report_at_seconds = int((datetime.utcnow() - datetime(1970, 1, 1)).total_seconds())

    req = orders_pb2.TDevSendPartnerReportsReq(
        BillingPartnerId=int(args.billing_client_id),
        PeriodStart=google_pb2.Timestamp(
            seconds=seconds_from_date(period_start_date)
        ),
        PeriodEnd=google_pb2.Timestamp(
            seconds=seconds_from_date(period_end_date)
        ),
        Email=args.email,
        ReportAt=google_pb2.Timestamp(seconds=report_at_seconds),
        ReportType=args.report_type
    )
    with create_grpc_channel(args) as channel:
        stub = orders_pb2_grpc.ReportDevInterfaceDoNotUseStub(channel)
        rsp = stub.SendPartnerReports(req)
        logging.info("Report sending scheduled. Operation id %s", rsp.OperationId)


def handle_planPartnerReportsSending(args):
    """Schedule report sending for a period and log the resulting operation id.

    ``args.period_start`` / ``args.period_end`` are parsed, truncated to
    calendar days and encoded via ``seconds_from_date``; ``args.report_type``
    is passed through unchanged.
    """
    def to_day_timestamp(text):
        # Parse, keep only the calendar day, and encode midnight of that day.
        day = parse_datetime_iso(text).date()
        return google_pb2.Timestamp(seconds=seconds_from_date(day))

    request = orders_pb2.TDevPlanPartnerReportsSendingReq(
        PeriodStart=to_day_timestamp(args.period_start),
        PeriodEnd=to_day_timestamp(args.period_end),
        ReportType=args.report_type,
    )
    with create_grpc_channel(args) as grpc_channel:
        report_api = orders_pb2_grpc.ReportDevInterfaceDoNotUseStub(grpc_channel)
        reply = report_api.PlanPartnerReportsSending(request)
        logging.info("Report sending scheduled. Operation id {}".format(reply.OperationId))


def parse_datetime_iso(dt_text):
    """Parse an ISO 8601 date or date-time string with the ``parse`` library's ``{:ti}`` type.

    Args:
        dt_text: text such as ``"2023-01-31"`` or ``"2023-01-31T10:00:00Z"``.

    Returns:
        The parsed ``datetime``.

    Raises:
        ValueError: if *dt_text* does not match the ISO 8601 pattern.
            Previously ``parse()`` returned ``None`` and indexing it raised an
            uninformative ``TypeError``.
    """
    result = parse('{:ti}', dt_text)
    if result is None:
        raise ValueError("Not an ISO 8601 date/time: %r" % (dt_text,))
    return result[0]


def seconds_from_date(d):
    """Return the Unix timestamp of midnight UTC at the start of day ``d``.

    ``d`` may be a ``date`` or a ``datetime``; for a ``datetime`` only the
    calendar date is used (time of day and tzinfo are ignored).
    """
    days_since_epoch = d.toordinal() - date(1970, 1, 1).toordinal()
    return days_since_epoch * 24 * 60 * 60


def last_day_of_month(any_day):
    """Return ``any_day`` moved to the last day of its month.

    Works for ``date`` and ``datetime``; a ``datetime`` keeps its time of day.
    """
    # The 1st plus 32 days always lands in the following month (no month has
    # more than 31 days). Snap back to that month's 1st and step back one day.
    first_of_next_month = (any_day.replace(day=1) + timedelta(days=32)).replace(day=1)
    return first_of_next_month - timedelta(days=1)


def main():
    """Parse the command line, configure logging and run the selected sub-command.

    Each sub-command registers its handler with ``set_defaults(func=...)``, and
    the handler receives the parsed ``argparse.Namespace``. Period arguments
    default to the whole previous calendar month.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", action="store_true", default=False)
    parser.add_argument("--host", default="localhost")
    # type=int so a malformed port is rejected by argparse with a clear message.
    parser.add_argument("--port", default=30858, type=int)
    parser.add_argument("--timeout", default=3, type=int)
    parser.add_argument("--tvm-client-id", default=2010758, type=int)
    parser.add_argument("--tvm-service-id", default=2002740, type=int)
    parser.add_argument("--tvm-client-secret")
    parser.add_argument("--tvm-service-ticket")
    parser.add_argument("--skip-authentication", action='store_true', default=False)
    # On Python 3, sub-parsers are optional by default. Without a sub-command,
    # ``args.func`` would be missing and main() would die with AttributeError.
    # ``dest`` is set so argparse can name the missing argument in its error.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    # Default reporting period: first..last day of the previous month.
    start = (date.today() + relativedelta(months=-1)).replace(day=1)
    end = last_day_of_month(start)

    # Billing ids are converted with int() by the handlers, so they are
    # mandatory. A missing id previously crashed with TypeError on int(None).
    createPartnerPayoutReport_parser = subparsers.add_parser("createPartnerPayoutReport")
    createPartnerPayoutReport_parser.set_defaults(func=handle_createPartnerPayoutReport)
    createPartnerPayoutReport_parser.add_argument("--billing-client-id", required=True)
    createPartnerPayoutReport_parser.add_argument("--billing-contract-id", required=True)
    createPartnerPayoutReport_parser.add_argument("--external-contract-id")
    createPartnerPayoutReport_parser.add_argument("--period-start",
                                                  default=start.isoformat())
    createPartnerPayoutReport_parser.add_argument("--period-end", default=end.isoformat())
    createPartnerPayoutReport_parser.add_argument("--output-file", default="partner_payouts.xlsx")

    createPartnerOrdersReport_parser = subparsers.add_parser("createPartnerOrdersReport")
    createPartnerOrdersReport_parser.set_defaults(func=handle_createPartnerOrdersReport)
    createPartnerOrdersReport_parser.add_argument("--billing-client-id", required=True)
    createPartnerOrdersReport_parser.add_argument("--period-start",
                                                  default=start.isoformat())
    createPartnerOrdersReport_parser.add_argument("--period-end", default=end.isoformat())
    createPartnerOrdersReport_parser.add_argument("--output-file", default="partner_orders.xlsx")

    createPartnerPaymentOrderReport_parser = subparsers.add_parser("createPartnerPaymentOrderReport")
    createPartnerPaymentOrderReport_parser.set_defaults(func=handle_createPartnerPaymentOrderReport)
    createPartnerPaymentOrderReport_parser.add_argument("--payment-order-id")
    createPartnerPaymentOrderReport_parser.add_argument("--payment-batch-id")
    createPartnerPaymentOrderReport_parser.add_argument("--output-file", default="partner_payment_orders.xlsx")

    sendPartnerPaymentOrderReport_parser = subparsers.add_parser("sendPartnerPaymentOrderReport")
    sendPartnerPaymentOrderReport_parser.set_defaults(func=handle_sendPartnerPaymentOrderReport)
    sendPartnerPaymentOrderReport_parser.add_argument("--payment-order-id")
    sendPartnerPaymentOrderReport_parser.add_argument("--payment-batch-id")
    sendPartnerPaymentOrderReport_parser.add_argument("--email")

    sendPartnerReports_parser = subparsers.add_parser("sendPartnerReports")
    sendPartnerReports_parser.set_defaults(func=handle_sendPartnerReports)
    sendPartnerReports_parser.add_argument("--billing-client-id", required=True)
    sendPartnerReports_parser.add_argument("--period-start", default=start.isoformat())
    sendPartnerReports_parser.add_argument("--period-end", default=end.isoformat())
    sendPartnerReports_parser.add_argument("--email")
    sendPartnerReports_parser.add_argument("--report-type", default="orders")

    planPartnerReportsSending_parser = subparsers.add_parser("planPartnerReportsSending")
    planPartnerReportsSending_parser.set_defaults(func=handle_planPartnerReportsSending)
    planPartnerReportsSending_parser.add_argument("--period-start", default=start.isoformat())
    planPartnerReportsSending_parser.add_argument("--period-end", default=end.isoformat())
    planPartnerReportsSending_parser.add_argument("--report-type", default="orders")

    args = parser.parse_args()
    logging.basicConfig(level=(logging.DEBUG if args.verbose else logging.INFO),
                        format="%(asctime)-15s | %(levelname)s | %(message)s",
                        stream=sys.stdout)

    args.func(args)


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
