# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
import math
from collections import namedtuple
from decimal import Decimal, ROUND_HALF_UP

import grpc
from django.conf import settings
from grpc_opentracing import open_tracing_client_interceptor
from grpc_opentracing.grpcext import intercept_channel
from opentracing import global_tracer

from common.data_api.tvm.instance import tvm_factory
from common.models.currency import Price
from common.settings.configuration import Configuration
from common.settings.utils import define_setting
from common.utils.yasmutil import Metric, MeasurableDecorator
from google.protobuf.timestamp_pb2 import Timestamp
from travel.proto import commons_pb2
from travel.rasp.train_bandit_api.proto import api_pb2, api_pb2_grpc

log = logging.getLogger(__name__)

# Pricing-context attributes of a request; BanditClient.to_proto_context turns one into api_pb2.TContext.
BanditContext = namedtuple(
    'BanditContext',
    'icookie point_from point_to departure arrival train_type car_type in_suburban_search',
)
# One train segment to price: a BanditContext, Price-like ``amount``/``service_amount``
# (read via ``.value``/``.currency``), its index, service class and ``log_info``
# (read via ``.req_id``/``.yandex_uid``/``.user_device`` -- presumably a RequestContext).
SegmentData = namedtuple('SegmentData', 'context amount service_amount index service_class log_info')
# Request identifiers attached to bandit-side logging.
RequestContext = namedtuple('RequestContext', 'req_id yandex_uid user_device')
# Bandit fee for one segment.
Fee = namedtuple('Fee', 'fee fee_percent is_bandit_fee_applied bandit_type bandit_version')
# One coach to price: a BanditContext, its PlaceData items, service class and log info.
CoachData = namedtuple('CoachData', 'context places service_class log_info')
# One place of a coach with its amount and service amount.
PlaceData = namedtuple('PlaceData', 'id amount service_amount')
# Bandit fee for one place, split into the main and the service part.
PlaceFee = namedtuple(
    'PlaceFee',
    'main_fee service_fee fee_percent is_bandit_fee_applied bandit_type bandit_version',
)

# Number of decimal places used for every price exchanged with the bandit API.
PRECISION = 2
# The only currency segment prices may be in; anything else is rejected before the request.
RUB_CURRENCY = 'RUB'

# gRPC address of the train bandit API: the production balancer for Configuration.PRODUCTION,
# the testing balancer otherwise (presumably how define_setting resolves the dict and ``default``).
define_setting(
    'BANDIT_API_ENDPOINT',
    {Configuration.PRODUCTION: 'travel-trains-bandit-production.balancer.internal.yandex.net:9111'},
    default='travel-trains-bandit-testing.balancer.internal.yandex.net:9111',
)

# Passed as the gRPC ``timeout`` of each GetCharge call; the configured value is coerced with int.
define_setting('BANDIT_API_TIMEOUT', converter=int, default=2)


class BanditMeasurable(MeasurableDecorator):
    """Metrics decorator for calls to the train bandit API.

    On error, appends two counter metrics to whatever the base class reports:
    one named via ``self._name('errors_cnt')`` and one named literally ``'errors_cnt'``.
    """

    prefix = 'train-bandit-api'

    def _handle_error(self, exc):
        metrics = super(BanditMeasurable, self)._handle_error(exc)
        error_counters = [Metric(counter_name, 1, 'ammm')
                          for counter_name in (self._name('errors_cnt'), 'errors_cnt')]
        metrics.extend(error_counters)
        return metrics


class BanditClientException(Exception):
    """Raised by BanditClient when a charge request cannot be built (e.g. a non-RUB price)."""


class BanditClient(object):
    """Client of the train bandit gRPC API (``BanditApiServiceV1``).

    Builds ``TGetChargeRequest`` messages from segment/coach prices, calls
    ``GetCharge`` with a TVM service ticket and converts the returned fees back
    into ``Fee`` / ``PlaceFee`` tuples. All money arithmetic is done with exact
    ``Decimal`` values and rounded half-up.
    """

    bandit_type = None
    stub = None

    def __init__(self, bandit_type, tvm_bandit_id=settings.TVM_BANDIT, tvm=tvm_factory):
        """
        :param bandit_type: value sent as ``BanditType`` in every charge request
        :param tvm_bandit_id: TVM id of the bandit service; a service ticket for it is fetched per call
        :param tvm: TVM factory whose ``get_provider()`` issues service tickets
        """
        self.tvm_bandit_id = tvm_bandit_id
        self.tvm_factory = tvm
        channel = grpc.insecure_channel(settings.BANDIT_API_ENDPOINT, options=None)
        # Trace every RPC made through this channel with the global opentracing tracer.
        interceptor = open_tracing_client_interceptor(global_tracer())
        channel = intercept_channel(channel, interceptor)
        self.stub = api_pb2_grpc.BanditApiServiceV1Stub(channel)
        self.bandit_type = bandit_type

    @classmethod
    def _quantize(cls, decimal):
        """Round ``decimal`` to two decimal places, rounding ties away from zero."""
        return decimal.quantize(Decimal('0.00'), rounding=ROUND_HALF_UP)

    @classmethod
    def _to_minor_units(cls, value):
        """Convert a price to an integer count of ``10 ** -PRECISION`` units, rounding half up.

        Decimal values are scaled exactly, so ``Decimal('1.005')`` becomes 101 (float
        arithmetic gives 100). Other values go through ``float()`` first, so ints,
        floats and numeric strings are accepted exactly as before.
        """
        if not isinstance(value, Decimal):
            value = Decimal(float(value))
        return int(value.scaleb(PRECISION).quantize(Decimal(1), rounding=ROUND_HALF_UP))

    @classmethod
    def _create_ticket_price(cls, amount_price, service_amount_price):
        """Build a ``TTicketPrice`` with both amounts in RUB at ``PRECISION`` decimal places.

        :param amount_price: ticket amount (Decimal, int, float or numeric string)
        :param service_amount_price: service amount, same accepted types
        """
        return api_pb2.TTicketPrice(
            Amount=commons_pb2.TPrice(
                Amount=cls._to_minor_units(amount_price), Precision=PRECISION, Currency=commons_pb2.C_RUB,
            ),
            ServiceAmount=commons_pb2.TPrice(
                Amount=cls._to_minor_units(service_amount_price), Precision=PRECISION, Currency=commons_pb2.C_RUB,
            ),
        )

    @staticmethod
    def _from_proto_price(price):
        """Return the exact Decimal value of a proto price, i.e. ``Amount * 10 ** -Precision``."""
        return Decimal(price.Amount).scaleb(-price.Precision)

    @classmethod
    def _permille_to_fraction(cls, permille):
        """Convert a permille value to a fraction quantized to two decimals, using exact decimal math.

        NOTE(review): two decimals drop permille resolution (115 -> 0.12); kept as-is because
        callers receive this value as ``fee_percent`` -- confirm this resolution is intended.
        """
        return cls._quantize(Decimal(permille) / 1000)

    @staticmethod
    def _find_or_add_context(contexts_with_prices, context_to_index, context, **extra_fields):
        """Return the ``TTicketPrices`` entry for ``context``, appending a new one when unseen.

        Contexts are de-duplicated by their serialized bytes. A new entry gets an
        ``InternalId`` equal to its position in ``contexts_with_prices``;
        ``extra_fields`` (e.g. ``AdditionalLogInfo``) are only used for a new entry.
        The returned message is the list element itself, so callers can fill its
        ``TicketPrices`` in place.
        """
        key = context.SerializeToString()
        index = context_to_index.get(key)
        if index is None:
            index = len(contexts_with_prices)
            context_to_index[key] = index
            contexts_with_prices.append(api_pb2.TTicketPrices(
                Context=context,
                InternalId=index,
                TicketPrices={},
                **extra_fields
            ))
        return contexts_with_prices[index]

    def get_fee_for_segments(self, segment_datas):
        """
        Request bandit fees for train segments.

        Segments sharing an identical context are sent as one ``TTicketPrices`` entry
        (carrying the ``AdditionalLogInfo`` of the first such segment). Every ticket
        price is keyed by the segment's position in ``segment_datas``, so the response
        can be mapped back.

        :param segment_datas: type: iterable of SegmentData
        :return: type: dict of (SegmentData, Fee); empty, without calling the API, when there are no segments
        :raises BanditClientException: if the amount, or a present service amount, is not in RUB
        """
        segment_datas = list(segment_datas or ())
        if not segment_datas:
            return {}

        contexts_with_prices = []
        context_to_index = {}
        for index, segment_data in enumerate(segment_datas):
            service_amount = segment_data.service_amount
            # service_amount is optional (sent as 0 when missing), so its currency is only checked when present.
            if (segment_data.amount.currency != RUB_CURRENCY
                    or (service_amount is not None and service_amount.currency != RUB_CURRENCY)):
                raise BanditClientException('Currency is not available')

            context = self.to_proto_context(segment_data.context)
            log_info = api_pb2.TAdditionalLogInfo(
                EventType=api_pb2.ET_TRAIN_TARIFFS,
                YandexUID=segment_data.log_info.yandex_uid,
                UserDevice=segment_data.log_info.user_device,
                ReqID=segment_data.log_info.req_id,
                ServiceClass=segment_data.service_class,
            )
            ticket_price = self._create_ticket_price(
                amount_price=segment_data.amount.value,
                service_amount_price=service_amount.value if service_amount is not None else 0,
            )
            context_with_price = self._find_or_add_context(
                contexts_with_prices, context_to_index, context, AdditionalLogInfo=log_info,
            )
            context_with_price.TicketPrices[index].CopyFrom(ticket_price)

        response = self._get_charge(api_pb2.TGetChargeRequest(
            BanditType=self.bandit_type,
            ContextsWithPrices=contexts_with_prices,
        ))

        segments_with_fee = {}
        for charge in response.ChargesByContexts:
            fee_percent = self._permille_to_fraction(charge.Permille)
            for index, ticket_fee in charge.TicketFees.items():
                # Keys are expected to be the positions sent above; anything else cannot be mapped back.
                if not 0 <= index < len(segment_datas):
                    log.error('Unknown segment fee: index=%d fee=%s', index, ticket_fee)
                    continue
                full_fee = self._from_proto_price(ticket_fee.Fee) + self._from_proto_price(ticket_fee.ServiceFee)
                segments_with_fee[segment_datas[index]] = Fee(
                    fee=Price(value=self._quantize(full_fee), currency=RUB_CURRENCY),
                    fee_percent=fee_percent,
                    is_bandit_fee_applied=ticket_fee.IsBanditFeeApplied,
                    bandit_type=charge.BanditType,
                    bandit_version=charge.BanditVersion,
                )

        return segments_with_fee

    @staticmethod
    def to_proto_context(context):
        """Convert a BanditContext into ``api_pb2.TContext``.

        ``arrival`` and ``departure`` are wrapped as ``Timestamp(seconds=...)``
        (presumably Unix epoch seconds -- confirm with callers).
        """
        return api_pb2.TContext(
            ICookie=context.icookie,
            PointFrom=context.point_from,
            PointTo=context.point_to,
            Arrival=Timestamp(seconds=context.arrival),
            Departure=Timestamp(seconds=context.departure),
            TrainType=context.train_type,
            CarType=context.car_type,
            InSuburbanSearch=context.in_suburban_search,
        )

    @BanditMeasurable(endpoint_name='GetCharge')
    def _get_charge(self, request):
        """Call ``GetCharge`` with a TVM service ticket and ``settings.BANDIT_API_TIMEOUT`` as the timeout.

        Wrapped in BanditMeasurable, which records metrics for the call.
        """
        service_ticket = self.tvm_factory.get_provider().get_ticket(self.tvm_bandit_id)
        response = self.stub.GetCharge(request, timeout=settings.BANDIT_API_TIMEOUT, metadata=(
            ('x-ya-service-ticket', service_ticket),
        ))
        return response

    def get_fee_for_coaches(self, coach_datas):
        """
        Request bandit fees for the places of train coaches.

        Coaches sharing an identical context are sent as one ``TTicketPrices`` entry;
        every ticket price is keyed by ``PlaceData.id``. Fees for ids that were not
        sent are logged and ignored.

        :param coach_datas: type: iterable of CoachData
        :return: type: dict of (PlaceData, PlaceFee); empty, without calling the API, when there are no coaches
        """
        coach_datas = list(coach_datas or ())
        if not coach_datas:
            return {}

        contexts_with_prices = []
        context_to_index = {}
        place_datas_by_id = {}
        for coach_data in coach_datas:
            # Bandit-side logging of TRAIN_DETAILS events is intentionally disabled, so unlike
            # get_fee_for_segments no AdditionalLogInfo is attached to these contexts.
            context_with_price = self._find_or_add_context(
                contexts_with_prices, context_to_index, self.to_proto_context(coach_data.context),
            )
            for place_data in coach_data.places:
                ticket_price = self._create_ticket_price(place_data.amount, place_data.service_amount)
                context_with_price.TicketPrices[place_data.id].CopyFrom(ticket_price)
                # NOTE(review): keyed by id alone -- if one id appears in coaches with different
                # contexts, the last PlaceData wins for all of them; confirm ids are unique per request.
                place_datas_by_id[place_data.id] = place_data

        response = self._get_charge(api_pb2.TGetChargeRequest(
            BanditType=self.bandit_type,
            ContextsWithPrices=contexts_with_prices,
        ))

        place_fees_by_place_data = {}
        for charge in response.ChargesByContexts:
            fee_percent = self._permille_to_fraction(charge.Permille)
            for place_id, ticket_fee in charge.TicketFees.items():
                place_data = place_datas_by_id.get(place_id)
                if place_data is None:
                    log.error('Redundant place fee: place_id=%d fee=%s', place_id, ticket_fee)
                    continue
                place_fees_by_place_data[place_data] = PlaceFee(
                    main_fee=self._quantize(self._from_proto_price(ticket_fee.Fee)),
                    service_fee=self._quantize(self._from_proto_price(ticket_fee.ServiceFee)),
                    fee_percent=fee_percent,
                    is_bandit_fee_applied=ticket_fee.IsBanditFeeApplied,
                    bandit_type=charge.BanditType,
                    bandit_version=charge.BanditVersion,
                )

        return place_fees_by_place_data
