import asyncio
import logging
import math
from concurrent import futures
from datetime import datetime

import grpc
from google.protobuf.timestamp_pb2 import Timestamp
from grpc_opentracing import open_tracing_client_interceptor
from grpc_opentracing.grpcext import intercept_channel
from opentracing import global_tracer

from travel.proto import commons_pb2
from travel.rasp.train_bandit_api.proto import api_pb2, api_pb2_grpc
from travel.rasp.pathfinder_proxy.const import TTransport, UTC_TZ

log = logging.getLogger(__name__)

# Number of decimal places for prices sent to the Bandit API: amounts are
# transmitted as integers scaled by 10 ** PRECISION together with this value.
PRECISION = 2


class TrainFeeService(object):
    """Adds Bandit API fees to train tariff prices inside transfer variants.

    Priced train tariff classes are sent to the Bandit API ``GetCharge`` method
    over gRPC; the returned fee and service fee are added to each tariff's
    ``price.value`` in place.
    """

    # Overall wait (seconds) enforced on the asyncio side by _async_get_charge.
    # It is independent of ``self.timeout``, which is the deadline passed to the
    # gRPC call itself.
    _ASYNC_WAIT_TIMEOUT = 1

    def __init__(self, tvm, settings):
        """
        :param tvm: factory whose ``get_provider().get_ticket(dst)`` returns the
            TVM service ticket sent as ``x-ya-service-ticket`` metadata.
        :param settings: object providing ``BANDIT_API_ENDPOINT``,
            ``BANDIT_API_TIMEOUT`` and ``TVM_DESTINATIONS`` (a mapping that may
            contain ``'TVM_BANDIT'``).
        """
        self.endpoint = settings.BANDIT_API_ENDPOINT
        self.timeout = settings.BANDIT_API_TIMEOUT
        self.tvm_bandit_id = settings.TVM_DESTINATIONS.get('TVM_BANDIT')
        self.tvm_factory = tvm
        # The gRPC call is blocking; it is run on this pool so the event loop is not blocked.
        self._executer = futures.ThreadPoolExecutor(max_workers=30)

    @staticmethod
    def _to_minor_units(value):
        """Scale ``value`` by 10 ** PRECISION and round half up (for non-negative values) to an int."""
        return int(float(value) * math.pow(10, PRECISION) + 0.5)

    @classmethod
    def _create_ticket_price(cls, amount_price, service_amount_price):
        """Build a RUB ``TTicketPrice`` with ``PRECISION`` decimal places.

        :param amount_price: value for the ``Amount`` field (anything ``float()`` accepts).
        :param service_amount_price: value for the ``ServiceAmount`` field
            (anything ``float()`` accepts).
        """
        return api_pb2.TTicketPrice(
            Amount=commons_pb2.TPrice(Amount=cls._to_minor_units(amount_price), Precision=PRECISION,
                                      Currency=commons_pb2.C_RUB),
            ServiceAmount=commons_pb2.TPrice(Amount=cls._to_minor_units(service_amount_price),
                                             Precision=PRECISION, Currency=commons_pb2.C_RUB),
        )

    @classmethod
    def _get_timestamp(cls, dt):
        """Return integer Unix seconds for ``dt`` (a datetime or an ISO-8601 string).

        NOTE: a naive datetime is interpreted by ``astimezone`` as system local time.
        """
        if isinstance(dt, str):
            dt = datetime.fromisoformat(dt)
        return int(dt.astimezone(UTC_TZ).timestamp())

    async def apply_fee(self, transfer_variants, icookie, bandit_type, yandex_uid, user_device, req_id):
        """Add Bandit API fees to train tariff prices of ``transfer_variants``.

        Tariff class dicts are updated in place: ``price.value`` becomes the
        original value plus the returned fee and service fee, rounded to
        ``PRECISION`` decimal places.

        :param transfer_variants: iterable of transfer dicts; only segments whose
            ``transport.code`` is the TRAIN transport name are priced.
        :param icookie: user ICookie; when falsy no fees are requested.
        :param bandit_type: value for ``TGetChargeRequest.BanditType``.
        :param yandex_uid: copied into ``TAdditionalLogInfo.YandexUID``.
        :param user_device: copied into ``TAdditionalLogInfo.UserDevice``.
        :param req_id: copied into ``TAdditionalLogInfo.ReqID``.
        :return: ``transfer_variants`` itself when it or ``icookie`` is empty,
            otherwise a list of the same variant objects.
        :raises: errors from ``_async_get_charge`` (e.g. ``asyncio.TimeoutError``
            or gRPC errors) are not caught here.
        """
        if not transfer_variants or not icookie:
            return transfer_variants

        transfer_variants = list(transfer_variants)
        contexts_with_prices, class_data_by_price_index = self._collect_contexts(
            transfer_variants, icookie, yandex_uid, user_device, req_id,
        )
        if not contexts_with_prices:
            return transfer_variants

        request = api_pb2.TGetChargeRequest(
            BanditType=bandit_type,
            ContextsWithPrices=contexts_with_prices,
        )
        response = await self._async_get_charge(request)
        self._apply_charges(response, class_data_by_price_index)
        return transfer_variants

    def _collect_contexts(self, transfer_variants, icookie, yandex_uid, user_device, req_id):
        """Group priced train tariff classes into one ``TTicketPrices`` per distinct context.

        Every priced tariff class gets a unique integer price index (the key in
        ``TicketPrices``); the returned dict maps that index back to the class
        dict so the fee can be applied to it later.

        :return: ``(contexts_with_prices, class_data_by_price_index)``.
        """
        train_code = None  # resolved lazily, only once a segment is actually seen
        prices_by_context = {}  # serialized TContext -> TTicketPrices (insertion-ordered)
        class_data_by_price_index = {}
        for transfer_variant in transfer_variants:
            for segment_data in transfer_variant.get('segments') or ():
                if train_code is None:
                    train_code = TTransport.get_name(TTransport.TRAIN)
                if segment_data['transport']['code'] != train_code:
                    continue
                tariffs = segment_data.get('tariffs')
                if not tariffs or not tariffs.get('classes'):
                    continue
                for car_type, class_data in tariffs['classes'].items():
                    if not class_data.get('price'):
                        continue
                    context = api_pb2.TContext(
                        ICookie=icookie,
                        PointFrom='s{}'.format(segment_data['stationFrom']['id']),
                        PointTo='s{}'.format(segment_data['stationTo']['id']),
                        Arrival=Timestamp(seconds=self._get_timestamp(segment_data['arrival'])),
                        Departure=Timestamp(seconds=self._get_timestamp(segment_data['departure'])),
                        TrainType=segment_data['rawTrainName'],
                        CarType=car_type,
                    )
                    service_price = 0.0
                    if class_data.get('servicePrice'):
                        service_price = class_data['servicePrice']['value']
                    ticket_price = self._create_ticket_price(
                        amount_price=class_data['price']['value'],
                        service_amount_price=service_price,
                    )

                    context_key = context.SerializeToString()
                    context_with_price = prices_by_context.get(context_key)
                    if context_with_price is None:
                        context_with_price = api_pb2.TTicketPrices(
                            Context=context,
                            InternalId=len(prices_by_context),
                            AdditionalLogInfo=api_pb2.TAdditionalLogInfo(
                                EventType=api_pb2.ET_TRAIN_TARIFFS,
                                YandexUID=yandex_uid,
                                UserDevice=user_device,
                                ReqID=req_id,
                            ),
                        )
                        prices_by_context[context_key] = context_with_price

                    # Indices are allocated sequentially, so the dict size is the next free one.
                    price_index = len(class_data_by_price_index)
                    context_with_price.TicketPrices[price_index].CopyFrom(ticket_price)
                    class_data_by_price_index[price_index] = class_data
        return list(prices_by_context.values()), class_data_by_price_index

    @staticmethod
    def _apply_charges(response, class_data_by_price_index):
        """Add every returned fee + service fee to the matching tariff's ``price.value``."""
        for charge in response.ChargesByContexts:
            for price_index, ticket_fee in charge.TicketFees.items():
                class_data = class_data_by_price_index.get(price_index)
                if class_data is None:
                    log.warning('Bandit API returned a fee for unknown price index %s', price_index)
                    continue
                full_fee = (ticket_fee.Fee.Amount / math.pow(10, ticket_fee.Fee.Precision)
                            + ticket_fee.ServiceFee.Amount / math.pow(10, ticket_fee.ServiceFee.Precision))
                price = class_data['price']
                # float() matches _create_ticket_price, which also accepts non-float values.
                price['value'] = round(full_fee + float(price['value']), PRECISION)

    def _get_charge(self, request):
        """Call ``GetCharge`` synchronously, authenticated with a TVM service ticket.

        A fresh channel is opened per call and closed afterwards so connections
        are not leaked.
        """
        with grpc.insecure_channel(self.endpoint, options=None) as channel:
            interceptor = open_tracing_client_interceptor(global_tracer())
            stub = api_pb2_grpc.BanditApiServiceV1Stub(intercept_channel(channel, interceptor))
            service_ticket = self.tvm_factory.get_provider().get_ticket(self.tvm_bandit_id)
            return stub.GetCharge(request, timeout=self.timeout, metadata=(
                ('x-ya-service-ticket', service_ticket),
            ))

    async def _async_get_charge(self, request):
        """Run ``_get_charge`` on the thread pool, waiting at most ``_ASYNC_WAIT_TIMEOUT`` seconds.

        :raises asyncio.TimeoutError: when the wait expires. The worker thread is
            not interrupted; its gRPC call is bounded only by ``self.timeout``.
        """
        loop = asyncio.get_running_loop()
        future = loop.run_in_executor(self._executer, self._get_charge, request)
        return await asyncio.wait_for(future, timeout=self._ASYNC_WAIT_TIMEOUT)
