# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

from datetime import datetime
from decimal import Decimal
import json
import logging
import pytz
import six
import time as os_time

from django.conf import settings
from google.protobuf.json_format import MessageToJson
import grpc
from grpc_opentracing import open_tracing_client_interceptor
from grpc_opentracing.grpcext import intercept_channel
from opentracing import global_tracer

from common.apps.train_order.enums import CoachType
from common.dynamic_settings.default import conf
from common.models.geo import Station
from common.models.tariffs import Setting
from common.settings.utils import define_setting, bool_converter
from travel.proto import commons_pb2
from travel.proto.dicts.rasp import thread_title_pb2
from travel.rasp.train_api.helpers.rps_limiter import rps_limiter, RPS_LIMITER_IM_SEARCH_KEY
from travel.rasp.train_api.tariffs.train.base.availability_indication import AvailabilityIndication
from travel.rasp.train_api.tariffs.train.base.models import TrainSegment, TrainTariff, Price, PlaceReservationType
from travel.rasp.train_api.tariffs.train.base.segments import fill_segment_fees
from travel.rasp.train_api.tariffs.train.base.service import WorkerNetworkError, WorkerParseError, UnknownResultType
from travel.rasp.train_api.tariffs.train.base.worker import (
    TrainTariffsResult, BaseErrorResult, PoolWorker, WorkerEmptyResultError
)
from travel.rasp.train_api.tariffs.train.im.parser import build_train_segments
from travel.rasp.train_api.tariffs.train.im.scarab_logger import log_response
from travel.rasp.train_api.tariffs.train.segment_builder.plain_prices import build_price_segments
from travel.rasp.train_api.tariffs.train.segment_builder.prices_and_reasons import (
    build_price_segments_with_reason_for_missing_prices
)
from travel.rasp.train_api.tariffs.train.segment_builder.search_segments import (
    get_search_segments_with_keys, fill_start_and_end_stations_from_thread
)
from travel.rasp.train_api.tariffs.train.wizard.service import WizardThreadInfo
from travel.rasp.train_api.train_partners.im.base import ImError, get_im_response, measurable
from travel.rasp.train_api.train_partners.mock_im import get_mock_im_train_pricing
from travel.trains.worker.api import worker_service_pb2
from travel.trains.worker.api import worker_service_pb2_grpc

log = logging.getLogger(__name__)


# IM endpoint that _send_query requests train prices from (via get_im_response).
TRAIN_PRICING_ENDPOINT = 'Railway/V1/Search/TrainPricing'

# USE_WORKER: together with train_query.train_api_use_worker, switches
# _get_result_with_reasons_from_partner_and_add_schedule_info to the gRPC trains worker.
# WORKER_ENDPOINT: gRPC target (host:port) passed to grpc.insecure_channel.
define_setting("USE_WORKER", default=False, converter=bool_converter)
define_setting('WORKER_ENDPOINT', default='localhost:9001')


class ImTariffsResult(TrainTariffsResult):
    """Train tariffs result obtained from IM, with IM-specific cache lifetimes."""

    @staticmethod
    def _setting_in_seconds(name):
        # The UFS_* timeout settings are multiplied by 60, i.e. they are kept in minutes.
        return Setting.get(name) * 60

    @property
    def cache_timeout(self):
        """Number of seconds this result may stay in the cache.

        Network errors are never cached (0). Other errors get a lifetime that
        depends on their kind.

        Raises:
            UnknownResultType: for a status/error combination not handled here.
        """
        if self.status == self.STATUS_PENDING:
            return settings.TARIFF_SUPPLIERWAIT_TIMEOUT

        if self.status == self.STATUS_SUCCESS:
            return self._setting_in_seconds('UFS_CACHE_TIMEOUT')

        if self.status == self.STATUS_ERROR:
            error = self.error
            if isinstance(error, WorkerNetworkError):
                return 0
            if self.is_empty_result():
                return self._setting_in_seconds('UFS_EMPTY_TIMEOUT')
            if isinstance(error, ImError):
                return self.get_im_error_cache_timeout()
            if isinstance(error, WorkerParseError):
                return self._setting_in_seconds('UFS_PARSE_ERROR_TIMEOUT')

        raise UnknownResultType(
            'Unknown result type status={} error={}'.format(self.status, self.error.__class__))

    def get_im_error_cache_timeout(self):
        """Cache lifetime in seconds for a result whose error is an ImError."""
        im_error = self.error
        uncacheable = (
            im_error.is_communication_error()
            or im_error.is_retry_allowed()
            or im_error.is_non_cacheable_error()
        )
        if uncacheable:
            return 0
        if im_error.is_trains_not_found_error():
            setting_name = 'UFS_EMPTY_TIMEOUT'
        else:
            setting_name = 'UFS_RESPONSE_ERROR_TIMEOUT'
        return self._setting_in_seconds(setting_name)


class ErrorResult(BaseErrorResult):
    """Raisable wrapper around an ``ImTariffsResult``.

    Raised inside the IM pipeline; callers catch it and use its ``.result``.
    ``ResultClass`` tells the base class which result type to build.
    """

    ResultClass = ImTariffsResult


def do_im_query(train_query, include_reason_for_missing_prices):
    """Schedule a background IM search and return a PENDING result at once.

    The PENDING result is written to the cache immediately. The actual partner
    request runs in a PoolWorker, which calls ``update_cache`` on its own result
    when done.
    """
    pending_result = ImTariffsResult(train_query, ImTariffsResult.STATUS_PENDING)
    pending_result.update_cache()

    # The worker runs outside the current scope, so hand it the caller's span explicitly.
    current_scope = global_tracer().scope_manager.active
    parent_span = getattr(current_scope, 'span', None)

    worker = PoolWorker(
        target=get_result_from_partner_and_add_schedule_info_and_save_to_cache,
        args=(train_query, include_reason_for_missing_prices, parent_span),
    )
    worker.start()

    return pending_result


def get_result_from_partner_and_add_schedule_info_and_save_to_cache(train_query, include_reason_for_missing_prices, active_span):
    """PoolWorker entry point: fetch tariffs and store the result in the cache.

    The work is wrapped in its own tracing span, tagged with the query and with
    the ids of ``active_span`` (the span active when the job was scheduled, may be None).
    """
    log.info('Обрабатываем запрос %s', train_query)
    tracer = global_tracer()
    span_name = 'train_api.tariffs:get_result_from_partner_and_add_schedule_info_and_save_to_cache'
    with tracer.start_active_span(span_name, finish_on_close=True) as scope:
        span = scope.span
        span.set_tag('train_query', repr(train_query))
        _link_span_to_parent(span, active_span)
        final_result = get_result_from_partner_and_add_schedule_info(train_query, include_reason_for_missing_prices)
        final_result.update_cache()


def get_result_from_partner_and_add_schedule_info(train_query, include_reason_for_missing_prices):
    """Run the partner query and return an ImTariffsResult.

    An ``ErrorResult`` raised during the query is unwrapped into its ``.result``.
    Any other exception propagates to the caller.
    """
    if include_reason_for_missing_prices:
        fetch = _get_result_with_reasons_from_partner_and_add_schedule_info
    else:
        fetch = _get_result_from_partner_and_add_schedule_info

    try:
        return fetch(train_query)
    except ErrorResult as error_result:
        return error_result.result


def _link_span_to_parent(current_span, origin_span):
    current_span.set_tag('origin.span_id', getattr(origin_span, 'span_id', 'empty_span_id'))
    try:
        current_span.set_tag('origin.trace_id', hex(int(getattr(origin_span, 'trace_id', '0')))[2:])
    except ValueError:
        pass


def _get_result_from_partner_and_add_schedule_info(train_query):
    """Fetch TrainPricing from IM (or its mock) and build a tariffs result.

    The IM request is sent with ``GetByLocalTime=False``. Segments are built
    with plain prices, without reasons for missing prices.

    Returns:
        A SUCCESS ImTariffsResult with segments, or an ERROR result carrying
        WorkerEmptyResultError when no segments were built.

    Raises:
        ErrorResult: on fetch or parse failures. Unexpected exceptions are
            wrapped into a WorkerParseError.
    """
    # Stage 1: obtain the raw partner response.
    try:
        response_data = (_get_mock_im_response(train_query) if train_query.mock_im
                         else _send_query(train_query, get_by_local_time=False))
    except ErrorResult:
        raise
    except Exception as exc:
        log.exception('Неожиданная ошибка при получении данных с IM')
        raise ErrorResult(train_query, error=WorkerParseError(message=six.text_type(exc)))

    # Stage 2: parse, price, and apply the standard fee (this fee goes to the cache and the wizard).
    try:
        parsed_segments = build_train_segments(response_data, train_query)
        priced_segments = build_price_segments(parsed_segments, train_query)
        segments = fill_segment_fees(priced_segments, train_query)
    except ErrorResult:
        raise
    except Exception as exc:
        log.exception('Неожиданная ошибка при разборе данных IM:')
        raise ErrorResult(train_query, error=WorkerParseError(message=six.text_type(exc)))

    # Stage 3: response logging is best-effort and must not fail the search.
    try:
        log_response(train_query, response_data)
    except Exception:
        log.exception('Ошибка при логировании ответа TrainPricing')

    if not segments:
        return ImTariffsResult(train_query, ImTariffsResult.STATUS_ERROR, error=WorkerEmptyResultError())
    return ImTariffsResult(train_query, segments=segments, status=ImTariffsResult.STATUS_SUCCESS)


def _get_mock_im_response(train_query):
    """Return the mocked IM TrainPricing response for ``train_query``.

    Raises:
        ErrorResult: wrapping any ImError raised by the mock.
    """
    try:
        mocked_response = get_mock_im_train_pricing(train_query)
    except ImError as im_error:
        raise ErrorResult(train_query, error=im_error)
    return mocked_response


# Output labels for thread title types, keyed by TThreadTitle enum value *name*.
# The worker JSON is produced by MessageToJson without use_integers_for_enums,
# so ``Type`` normally arrives as a name such as 'TYPE_SUBURBAN'.
_TITLE_TYPE_LABELS = {
    'TYPE_DEFAULT': 'default',
    'TYPE_SUBURBAN': 'suburban',
    'TYPE_MTA': 'mta',
}


def _title_type_label(raw_type):
    """Map a TThreadTitle type (enum name or integer value) to its label, or None if unknown."""
    label = _TITLE_TYPE_LABELS.get(raw_type)
    if label is not None:
        return label
    # Also accept integer enum values, e.g. from callers building the dict by hand.
    for enum_name, enum_label in _TITLE_TYPE_LABELS.items():
        if getattr(thread_title_pb2.TThreadTitle, enum_name) == raw_type:
            return enum_label
    return None


def _title_part_key(title_part):
    """Return 'c<SettlementId>' for a settlement part, otherwise 's<StationId>'."""
    settlement_id = title_part.get('SettlementId')
    # With including_default_value_fields=True an unset id may still be emitted
    # as its zero default (0, or "0" for int64), so zero means "no settlement".
    if settlement_id is not None and str(settlement_id) not in ('', '0'):
        return 'c' + str(settlement_id)
    return 's' + str(title_part['StationId'])


def make_title_dict(json_segment):
    """Convert a worker segment's ``titleDict`` into the ``title_common`` dict format.

    Returns:
        The raw (falsy) ``titleDict`` value when it is missing or empty.
        Otherwise a dict with the optional keys ``title_parts`` (a list of
        'c<SettlementId>' / 's<StationId>' strings) and ``type`` ('default',
        'suburban' or 'mta'). Unknown types are omitted.
    """
    title_dict = json_segment.get('titleDict', {})
    if not title_dict:
        return title_dict

    converted = {}
    if 'TitleParts' in title_dict:
        converted['title_parts'] = [_title_part_key(part) for part in title_dict['TitleParts']]
    if 'Type' in title_dict:
        label = _title_type_label(title_dict['Type'])
        if label is not None:
            converted['type'] = label
    return converted


def _currency_from_json(price_json):
    """Return the price currency, treating a missing value or 'C_UNKNOWN' as 'RUB'."""
    currency = price_json.get('Currency', 'RUB')
    return 'RUB' if currency == 'C_UNKNOWN' else currency


def _price_from_json(price_json, lenient=False):
    """Build a Price from a worker price object: Amount / 10 ** Precision.

    ``Amount`` may arrive as a string (int64 in protobuf JSON); Decimal accepts
    both forms. With ``lenient=True``, a missing Amount/Precision defaults to 0
    instead of raising KeyError.
    """
    if lenient:
        amount = price_json.get('Amount', 0)
        precision = price_json.get('Precision', 0)
    else:
        amount = price_json['Amount']
        precision = price_json['Precision']
    return Price(Decimal(amount) / 10 ** Decimal(precision), _currency_from_json(price_json))


def _empty_to_none(value):
    """Map an empty string (protobuf's default for unset strings) to None."""
    return None if value == '' else value


def convert_json_to_train_segment(json_segment):
    """Build a TrainSegment from one ``TariffTrain`` entry of the worker response.

    ``json_segment`` is a dict produced by MessageToJson with
    including_default_value_fields=True (see
    _get_result_with_reasons_from_partner_and_add_schedule_info).
    Each entry of ``places`` becomes a TrainTariff keyed by its coach type.
    Empty-string optional fields are normalised to None.
    """
    segment = TrainSegment()
    segment.tariffs['classes'] = {}
    all_broken_classes = json_segment.get('brokenClasses', {})
    segment.tariffs['broken_classes'] = {
        coach_type: value for coach_type, value in all_broken_classes.items() if len(value) != 0
    }

    places = json_segment['places']
    for place in places:
        price_details = place['priceDetails']
        tariff = TrainTariff(
            coach_type=CoachType(place['coachType']),
            ticket_price=_price_from_json(price_details['ticketPrice']),
            service_price=_price_from_json(price_details['servicePrice']),
            seats=place.get('count', 0),
            lower_seats=place.get('lowerCount', 0),
            upper_seats=place.get('upperCount', 0),
            lower_side_seats=place.get('lowerSideCount', 0),
            upper_side_seats=place.get('upperSideCount', 0),
            max_seats_in_the_same_car=place.get('maxSeatsInTheSameCar', 0),
            several_prices=place.get('severalPrices', False),
            place_reservation_type=PlaceReservationType.USUAL,
            is_transit_document_required=False,
            availability_indication=AvailabilityIndication.AVAILABLE,
            service_class=place['serviceClass'],
            has_non_refundable_tariff=place.get('hasNonRefundableTariff', False),
        )
        # The fee has the same Amount/Precision/Currency shape as the other
        # prices, so scale it the same way (Amount alone is in minor units).
        tariff.fee = _price_from_json(price_details['fee'], lenient=True)
        segment.tariffs['classes'][place['coachType']] = tariff
    if places:
        # Segment-level flag; set only when the segment has at least one place.
        segment.tariffs['electronic_ticket'] = json_segment.get('electronicTicket', False)

    segment.arrival = datetime.strptime(json_segment['arrival'], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=pytz.utc)
    segment.departure = datetime.strptime(json_segment['departure'], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=pytz.utc)
    segment.provider = _empty_to_none(json_segment['provider'])
    segment.has_dynamic_pricing = json_segment.get('hasDynamicPricing', False)
    segment.number = json_segment['number']
    segment.raw_train_name = _empty_to_none(json_segment.get('rawTrainName', None))
    segment.coach_owners = json_segment['coachOwners']
    segment.first_country_code = _empty_to_none(json_segment.get('firstCountryCode', None))
    segment.last_country_code = _empty_to_none(json_segment.get('lastCountryCode', None))
    segment.station_from_express_code = json_segment['departureStationId']
    segment.station_to_express_code = json_segment['arrivalStationId']
    # Bare Station objects whose id is the express code from the worker.
    segment.station_from = Station()
    segment.station_from.id = segment.station_from_express_code
    segment.station_to = Station()
    segment.station_to.id = segment.station_to_express_code
    segment.original_number = json_segment['number']
    segment.title_common = json.dumps(make_title_dict(json_segment))
    # hasThread is emitted as False (not omitted) with default-value fields, so
    # normalise "no thread" to None instead of leaving False in segment.thread.
    segment.thread = None
    if json_segment.get('hasThread'):
        add_thread_in_segment(segment)
    return segment


def add_thread_in_segment(segment):
    """Attach a minimal WizardThreadInfo to ``segment`` as its ``thread``.

    The thread copies the segment's number and its first/last country codes.
    ``deluxe_train`` is explicitly set to None.
    """
    thread_info = WizardThreadInfo()
    segment.thread = thread_info
    thread_info.number = segment.number
    thread_info.first_country_code = segment.first_country_code
    thread_info.last_country_code = segment.last_country_code
    thread_info.deluxe_train = None


def _get_result_with_reasons_from_partner_and_add_schedule_info(train_query):
    """Fetch tariffs (with reasons for missing prices) and wrap them in an ImTariffsResult.

    Uses the gRPC trains worker when both ``train_query.train_api_use_worker``
    and ``settings.USE_WORKER`` are truthy. Otherwise IM (or its mock) is queried directly.

    Returns:
        A SUCCESS result with segments, or an ERROR result carrying
        WorkerEmptyResultError when nothing was found.

    Raises:
        ErrorResult: on network, partner or parse failures.
    """
    if train_query.train_api_use_worker and settings.USE_WORKER:
        segments = _get_worker_segments(train_query)
    else:
        segments = _get_im_segments_with_reasons(train_query)

    if segments:
        return ImTariffsResult(train_query, segments=segments, status=ImTariffsResult.STATUS_SUCCESS)
    return ImTariffsResult(train_query, ImTariffsResult.STATUS_ERROR, error=WorkerEmptyResultError())


def _get_worker_segments(train_query):
    """Request tariffs from the trains worker over gRPC and convert them to TrainSegments.

    Returns an empty list when the worker answers EC_NOT_FOUND.

    Raises:
        ErrorResult: WorkerNetworkError on gRPC failures; WorkerParseError on a
            worker EC_GENERAL_ERROR or when the response cannot be converted.
    """
    log.debug("Используем новый worker")
    raw_channel = grpc.insecure_channel(settings.WORKER_ENDPOINT)
    try:
        channel = intercept_channel(raw_channel, open_tracing_client_interceptor(global_tracer()))
        stub = worker_service_pb2_grpc.WorkerBalancerServiceStub(channel)

        departure_date = train_query.departure_date
        worker_request = worker_service_pb2.TSegmentsRequest(
            From=train_query.departure_point_code,
            To=train_query.arrival_point_code,
            Date=commons_pb2.TDate(Year=departure_date.year, Month=departure_date.month, Day=departure_date.day),
            Header=worker_service_pb2.TRequestHeader(Priority=2, RequestId=""),
        )
        try:
            message = stub.GetTariffs(worker_request)
        except grpc.RpcError as rpc_error:
            # Without this the exception escapes the PoolWorker and the PENDING
            # cache entry is never replaced; handle it like _send_query does.
            log.exception('Не удалось получить ответ от worker')
            raise ErrorResult(train_query, error=WorkerNetworkError(six.text_type(rpc_error)))
    finally:
        # A channel is opened per request; release it. Channel.close() only
        # exists in newer grpcio releases, hence the getattr guard.
        close_channel = getattr(raw_channel, 'close', None)
        if close_channel is not None:
            close_channel()

    if message.Header.Code == commons_pb2.EC_NOT_FOUND:
        return []
    if message.Header.Code == commons_pb2.EC_GENERAL_ERROR:
        # Not inside an except block, so there is no traceback for log.exception to attach.
        log.error('Неожиданная ошибка при обработке данных IM: %s', message.Header.Error.Message)
        raise ErrorResult(train_query, error=WorkerParseError(message=six.text_type(message.Header.Error.Message)))

    try:
        json_segments = json.loads(MessageToJson(message, including_default_value_fields=True))
        log.debug("Сегменты из workera: %r", json_segments)
        segments = [convert_json_to_train_segment(json_segment) for json_segment in json_segments['TariffTrain']]
        return fill_segment_fees(segments, train_query)
    except ErrorResult:
        raise
    except Exception as e:
        log.exception('Неожиданная ошибка при разборе данных IM:')
        raise ErrorResult(train_query, error=WorkerParseError(message=six.text_type(e)))


def _get_im_segments_with_reasons(train_query):
    """Query IM (or its mock) and build segments with reasons for missing prices.

    Returns an empty list without calling IM when there are no search segments
    and TRAIN_PURCHASE_SKIP_IM_REQUEST_FOR_EMPTY_SEGMENTS is enabled. IM
    "empty result" / "trains not found" errors are tolerated while segments are
    still built from ``search_segments``. If that yields nothing, the original
    ErrorResult is re-raised.

    Raises:
        ErrorResult: on IM or parse failures.
    """
    search_segments = get_search_segments_with_keys(train_query)
    search_segments = fill_start_and_end_stations_from_thread(search_segments)
    if conf.TRAIN_PURCHASE_SKIP_IM_REQUEST_FOR_EMPTY_SEGMENTS and not search_segments:
        return []

    original_error = None
    try:
        if train_query.mock_im:
            response_data = _get_mock_im_response(train_query)
        else:
            response_data = _send_query(train_query, get_by_local_time=True)
    except ErrorResult as e:
        im_error = e.result.error
        if not (im_error and (im_error.is_empty_result_error() or im_error.is_trains_not_found_error())):
            raise
        response_data = None
        original_error = e
    except Exception as e:
        log.exception('Неожиданная ошибка при получении данных с IM')
        raise ErrorResult(train_query, error=WorkerParseError(message=six.text_type(e)))

    try:
        segments = build_train_segments(response_data, train_query)
        segments = build_price_segments_with_reason_for_missing_prices(segments, search_segments, train_query)
        segments = fill_segment_fees(segments, train_query)  # The standard fee is written to the cache and the wizard.
    except ErrorResult:
        raise
    except Exception as e:
        log.exception('Неожиданная ошибка при разборе данных IM:')
        raise ErrorResult(train_query, error=WorkerParseError(message=six.text_type(e)))

    # Response logging is best-effort and must not fail the search.
    try:
        log_response(train_query, response_data)
    except Exception:
        log.exception('Ошибка при логировании ответа TrainPricing')

    if not segments and original_error:
        raise original_error
    return segments


@measurable(endpoint_name='train_tariffs')
def _send_query(train_query, get_by_local_time):
    """Call IM's TrainPricing endpoint for ``train_query``.

    Args:
        train_query: search query providing departure/arrival codes and date.
        get_by_local_time: value sent as IM's ``GetByLocalTime`` parameter.

    Returns:
        The response returned by get_im_response.

    Raises:
        ErrorResult: wrapping the ImError from IM, or a WorkerNetworkError for
            any other failure.
    """
    request_params = {
        'Origin': train_query.departure_point_code,
        'Destination': train_query.arrival_point_code,
        'DepartureDate': train_query.departure_date.strftime('%Y-%m-%dT00:00:00'),
        'CarGrouping': 'DontGroup',
        'GetByLocalTime': get_by_local_time,
    }
    log.info('Спрашиваем TrainPricing IM. Request params: %r', request_params)

    started_at = os_time.time()
    rps_limiter.save_query_token(RPS_LIMITER_IM_SEARCH_KEY)
    try:
        im_response = get_im_response(TRAIN_PRICING_ENDPOINT, request_params)
    except ImError as im_error:
        raise ErrorResult(train_query, error=im_error)
    except Exception as unexpected:
        log.exception('Не удалось получить ответ от IM')
        raise ErrorResult(train_query, error=WorkerNetworkError(six.text_type(unexpected)))

    log.info('Получили ответ за %.3f', os_time.time() - started_at)
    return im_response


def postprocess_query_result(result, train_query, yandex_uid=None):
    """Recompute fees for ``result.segments`` in place and return ``result``.

    ``yandex_uid`` is passed through to fill_segment_fees. A result without
    segments is returned untouched.
    """
    if not result.segments:
        return result
    result.segments = fill_segment_fees(result.segments, train_query, yandex_uid=yandex_uid)
    return result
