# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import random
from contextlib import closing
from datetime import datetime, timedelta
from typing import Dict  # noqa

import grpc
from grpc._channel import _InactiveRpcError

from travel.avia.library.python.price_prediction.deploy_helpers import discover
from travel.avia.price_prediction.api.v1 import check_price_pb2, check_price_pb2_grpc

logger = logging.getLogger(__name__)


class DiscoveryError(Exception):
    pass


class PricePredictionClient(object):
    """
    Клиент подойдет при использовании L3-балансера,
    или если хотим всегда ходить на один хост, например в dev
    """
    def __init__(self, host, port=9001, options=None):
        self.host = host
        self.port = port
        self.options = options

    def build_target(self):
        return '{host}:{port}'.format(host=self.host, port=self.port)

    def check_prices(self, request, timeout, tvm_service_ticket):
        # type: (PricePredictionClient, check_price_pb2.TCheckPricesReq, float) -> Dict[int, check_price_pb2.TPricePrediction.ECategory]
        target = self.build_target()
        with closing(grpc.insecure_channel(target, self.options)) as channel:
            try:
                stub = check_price_pb2_grpc.CheckPriceServiceStub(channel)
                result = stub.CheckPrices(request, timeout, metadata=(
                    ('x-ya-service-ticket', tvm_service_ticket),
                ))
            except _InactiveRpcError as e:
                if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                    # TODO simon-ekb: send to solomon RASPTICKETS-20639
                    pass
                self.on_error()
                return _build_unknown_result(request.CheckPricesReq)
            except Exception:
                logger.exception('Error in check prices')
                self.on_error()
                return _build_unknown_result(request.CheckPricesReq)
            else:
                return result.PriceCategories

    def on_error(self):
        return


class RediscoveringPricePredictionClient(PricePredictionClient):
    """
    Клиент для client-side балансировки.
    Дискавери эндпоинтов в зависимости от окружения сервера. Дискавери происходит при инициализации,
    дальше при поступлении запроса, если дискавери не производили дольше, чем rediscover_interval.
    Выбирается произвольный из живых серверов. Если живых серверов нет, то выбирается произвольный из найденных.
    """
    def __init__(self, server_environment, options=None, rediscover_interval=timedelta(minutes=5)):
        self.rediscover_interval = rediscover_interval
        self.environment = server_environment
        self.options = options
        self.endpoints = None
        self.current_endpoint_index = 0
        self.last_discovery_try_dt = None

    def build_target(self):
        if not self.endpoints:
            raise DiscoveryError('Host is not discovered for price-prediction {}'.format(self.environment))

        endpoint = self.endpoints[self.current_endpoint_index]
        return '{host}:{port}'.format(host=endpoint.host, port=endpoint.port)

    def rediscover(self):
        # 1. Хотим чтобы как можно меньше запросов делали rediscover, поэтому первым делом проставим discovered_at.
        # 2. В случае ошибки не хотим делать discover на каждый запрос, поэтому выставление discovered_at заранее - ок
        self.last_discovery_try_dt = datetime.now()
        try:
            self.endpoints = discover(self.environment)
        except Exception:
            logger.exception('Error while discover hosts')
        else:
            random.shuffle(self.endpoints)
            self.current_endpoint_index = 0

    def on_error(self):
        self.shift_endpoint_index()

    def shift_endpoint_index(self):
        if not self.endpoints:
            return

        self.current_endpoint_index = (self.current_endpoint_index + 1) % len(self.endpoints)


def _build_unknown_result(check_price_request):
    return dict.fromkeys(check_price_request, check_price_pb2.TPricePrediction.CATEGORY_UNKNOWN)
