# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import operator
from concurrent.futures import Future
from datetime import timedelta, date, datetime, time  # noqa: UnusedImport
from contextlib import closing
from logging import Logger, getLogger  # noqa: UnusedImport

import pytz
from sqlalchemy import sql
from typing import Tuple, List  # noqa: UnusedImport

import common.utils.railway as railway_utils
from common.models.currency import Price
from common.models_utils.geo import Point  # noqa: UnusedImport
from common.utils.date import daterange
from travel.library.python.tracing.instrumentation import traced_function
from travel.rasp.wizards.train_wizard_api.lib.express_system_provider import ExpressSystemProvider, express_system_provider  # noqa: UnusedImport
from travel.rasp.wizards.train_wizard_api.lib.pgaas_price_store.db_models import DirectionTariffInfo
from travel.rasp.wizards.train_wizard_api.lib.pgaas_price_store.models.place import Place
from travel.rasp.wizards.train_wizard_api.lib.pgaas_price_store.models.tariff_direction_info import TariffDirectionInfo
from travel.rasp.wizards.train_wizard_api.lib.pgaas_price_store.models.tariff_direction_updated_info import (
    TariffDirectionUpdatedInfoRecord, TariffDirectionUpdatedInfo
)
from travel.rasp.wizards.train_wizard_api.lib.storage_store import StorageStore, storage_store  # noqa: UnusedImport
from travel.rasp.wizards.train_wizard_api.lib.storage_timed_execute import ExecutionTimeout, execute_with_timeout, get_future


class TariffDirectionInfoSource(object):
    """Read-only queries for stored direction tariff documents.

    Both queries select the ``data``, ``departure_date`` and ``updated_at``
    columns of ``DirectionTariffInfo`` and use the storage registered under
    the ``'slave'`` key of the injected storage store.
    """

    def __init__(self, storage_store):
        # type: (StorageStore) -> None
        self._storage_store = storage_store

    @traced_function(name='train_wizard_api.lib.pgaas_price_store.tariff_direction_info_provider.TariffDirectionInfoSource.find')
    def find(self, departure_point_express_id, arrival_point_express_id, left_border, right_border):
        # type: (int, int, date, date) -> List[Tuple[List[dict], date, datetime]]
        """Start a lookup of tariff rows for one direction and a date range.

        :param departure_point_express_id: express id of the departure point.
        :param arrival_point_express_id: express id of the arrival point.
        :param left_border: earliest ``departure_date`` to match (inclusive).
        :param right_border: latest ``departure_date`` to match (inclusive).
        :return: whatever ``get_future(storage, query)`` returns.
            NOTE(review): presumably a future resolving to the
            ``(data, departure_date, updated_at)`` rows described by the type
            comment rather than the rows themselves -- confirm against
            ``storage_timed_execute.get_future``.
        """
        storage = self._storage_store.get('slave')
        query = sql.select([
            DirectionTariffInfo.data,
            DirectionTariffInfo.departure_date,
            DirectionTariffInfo.updated_at
        ], whereclause=sql.and_(
            DirectionTariffInfo.departure_point_express_id == departure_point_express_id,
            DirectionTariffInfo.arrival_point_express_id == arrival_point_express_id,
            DirectionTariffInfo.departure_date >= left_border,
            DirectionTariffInfo.departure_date <= right_border,
        ))

        return get_future(storage, query)

    @traced_function(name='train_wizard_api.lib.pgaas_price_store.tariff_direction_info_provider.TariffDirectionInfoSource.find_tariffs_by_directions')
    def find_tariffs_by_directions(self, directions, left_border, right_border):
        # type: (List[Tuple[int, int]], date, date) -> List[Tuple[List[dict], date, datetime]]
        """Fetch tariff rows for several directions within a date range.

        Unlike ``find`` this runs the query immediately through a session that
        is closed before returning.

        :param directions: ``(departure_express_id, arrival_express_id)`` pairs;
            a row matches when it belongs to any of them.
        :param left_border: earliest ``departure_date`` to match (inclusive).
        :param right_border: latest ``departure_date`` to match (inclusive).
        :return: list of ``(data, departure_date, updated_at)`` rows; an empty
            list when no directions are given.
        """
        direction_constraints = [
            sql.and_(
                DirectionTariffInfo.departure_point_express_id == departure_express_id,
                DirectionTariffInfo.arrival_point_express_id == arrival_express_id
            )
            for departure_express_id, arrival_express_id in directions
        ]
        if not direction_constraints:
            # sql.or_() without operands is deprecated in SQLAlchemy and renders
            # nothing, which would leave only the date filter and match rows of
            # every direction. No requested direction means no rows.
            return []

        with closing(self._storage_store.get('slave').get_session()) as session:
            # The borders are passed as 'YYYY-MM-DD' strings here (``find``
            # passes the date objects themselves).
            return list(session.query(
                DirectionTariffInfo.data,
                DirectionTariffInfo.departure_date,
                DirectionTariffInfo.updated_at
            ).filter(
                sql.or_(*direction_constraints),
                DirectionTariffInfo.departure_date >= left_border.strftime('%Y-%m-%d'),
                DirectionTariffInfo.departure_date <= right_border.strftime('%Y-%m-%d'),
            ))


class TariffDirectionInfoProvider(object):
    """Looks up stored tariffs for train directions and builds model objects.

    Points are translated to express ids through the injected express system
    provider, rows are fetched through the injected source, and the raw rows
    are converted into ``TariffDirectionInfo`` / ``TariffDirectionUpdatedInfo``.
    """

    def __init__(self, tariff_direction_info_source, express_system_provider, railway_utils, logger):
        # type: (TariffDirectionInfoSource, ExpressSystemProvider, any, Logger) -> None
        self._tariff_direction_info_source = tariff_direction_info_source
        self._express_system_provider = express_system_provider
        self._railway_utils = railway_utils
        self._logger = logger

    @traced_function(name='train_wizard_api.lib.pgaas_price_store.tariff_direction_info_provider.TariffDirectionInfoProvider.find')
    def find(self, departure_point, arrival_point, departure_date, days=1):
        # type: (Point, Point, date, int) -> Tuple[Future, dict]
        """Start a tariff lookup for one direction.

        The result is meant to be completed later with ``build_info`` (rows
        available) or ``build_empty_info`` (no rows), passing the returned
        context back in.

        :param departure_point: point the trains depart from.
        :param arrival_point: point the trains arrive at.
        :param departure_date: first departure date of interest.
        :param days: length of the requested window in days.
        :return: ``(pending_rows, context)``. When an express id is missing for
            either point, ``pending_rows`` is an already completed ``Future``
            holding ``()`` and ``context`` is an empty dict. Otherwise
            ``pending_rows`` is the value returned by the source's ``find``
            (presumably a future -- confirm against ``get_future``) and
            ``context`` carries the dates and departure point needed to build
            the final result.
        """
        departure_point_key = departure_point.point_key
        arrival_point_key = arrival_point.point_key

        departure_point_express_id = self._express_system_provider.find_express_id(departure_point_key)
        arrival_point_express_id = self._express_system_provider.find_express_id(arrival_point_key)

        if not departure_point_express_id or not arrival_point_express_id:
            # Logger.warn is a deprecated alias (removed in Python 3.13);
            # Logger.warning behaves the same on every supported version.
            self._logger.warning(
                'Can not find prices by [%s-%s], because can not find express codes for one of points [%s-%s]',
                departure_point_express_id, arrival_point_express_id,
                departure_point_key, arrival_point_key,
            )
            # Keep the return shape uniform: a completed future with no rows
            # and an empty context, which _build_info_with_documents treats as
            # "nothing to build".
            result = Future()
            result.set_result(())
            return result, {}
        self._logger.info(
            'Start search: [%s(%s)-%s(%s)-%s-%s]',
            departure_point_key, departure_point_express_id, arrival_point_key,
            arrival_point_express_id, departure_date, days
        )

        max_departure_date = departure_date + timedelta(days=days)
        # The stored rows are queried from one day before the requested date
        # up to and including max_departure_date.
        left_border = departure_date - timedelta(days=1)
        right_border = max_departure_date

        context = {
            'max_departure_date': max_departure_date,
            'departure_date': departure_date,
            'departure_point': departure_point,
            'left_border': left_border,
            'right_border': right_border,
        }
        return (
            self._tariff_direction_info_source.find(
                departure_point_express_id=departure_point_express_id,
                arrival_point_express_id=arrival_point_express_id,
                left_border=left_border,
                right_border=right_border,
            ),
            context
        )

    def _build_info_with_documents(self, documents, context):
        """Build ``(tariff infos, updated info)`` from rows and a ``find`` context.

        :param documents: ``(data, departure_date, updated_at)`` rows.
        :param context: dict produced by ``find``; a falsy context (as returned
            when express ids are missing) yields an empty result.
        :return: tuple of ``TariffDirectionInfo`` objects and a
            ``TariffDirectionUpdatedInfo``.
        """
        if not context:
            return (), TariffDirectionUpdatedInfo(())
        return self._build_info(
            documents=documents,
            min_departure_date=context.get('departure_date'),
            max_departure_date=context.get('max_departure_date')
        ), self._build_updated_info(
            documents=documents,
            departure_point=context.get('departure_point'),
            left_border=context.get('left_border'),
            right_border=context.get('right_border')
        )

    def build_info(self, rows, context):
        """Build the final result from fetched rows and the ``find`` context."""
        # NOTE(review): this trace message is emitted at error level on every
        # call -- confirm that is intended.
        self._logger.error('TariffDirectionInfoProvider build_info')
        # Materialize the rows: they are iterated twice, once for the tariff
        # infos and once for the updated info.
        return self._build_info_with_documents(list(rows), context)

    def build_empty_info(self, context):
        """Build the final result for the ``find`` context when there are no rows."""
        # NOTE(review): this trace message is emitted at error level on every
        # call -- confirm that is intended.
        self._logger.error('TariffDirectionInfoProvider build_empty_info')
        return self._build_info_with_documents([], context)

    @traced_function(name='train_wizard_api.lib.pgaas_price_store.tariff_direction_info_provider.TariffDirectionInfoProvider._build_updated_info')
    def _build_updated_info(self, documents, departure_point, left_border, right_border):
        """Describe, per day of the queried range, when its tariffs were updated.

        :param documents: ``(data, departure_date, updated_at)`` rows.
        :param departure_point: point whose railway timezone is used for the
            day borders.
        :param left_border: first date of the range.
        :param right_border: last date of the range (``daterange`` is called
            with ``include_end=True``).
        :return: ``TariffDirectionUpdatedInfo`` with one record per date yielded
            by ``daterange``; ``updated_at`` is ``None`` for dates without a row.
        """
        railway_timezone = self._railway_utils.get_railway_tz_by_point(departure_point)
        # updated_at values are localized as UTC. If several rows share a
        # departure_date, the last one iterated wins.
        departure_date_to_updated_at = {
            departure_date: pytz.UTC.localize(updated_at) for _, departure_date, updated_at in documents
        }

        return TariffDirectionUpdatedInfo(
            records=tuple(sorted(
                (
                    TariffDirectionUpdatedInfoRecord(
                        # Each record spans the whole day in the railway timezone.
                        left_border=railway_timezone.localize(datetime.combine(departure_date, time.min)),
                        right_border=railway_timezone.localize(datetime.combine(departure_date, time.max)),
                        updated_at=departure_date_to_updated_at.get(departure_date)
                    )
                    for departure_date in daterange(left_border, right_border, include_end=True)
                ),
                key=operator.attrgetter('left_border')
            ))
        )

    @traced_function(name='train_wizard_api.lib.pgaas_price_store.tariff_direction_info_provider.TariffDirectionInfoProvider._build_info')
    def _build_info(self, documents, min_departure_date, max_departure_date):
        """Convert raw tariff rows into ``TariffDirectionInfo`` objects.

        :param documents: rows whose first element is a list of raw tariff
            dicts.
        :param min_departure_date: lower bound for ``departure_dt``.
        :param max_departure_date: upper bound for ``departure_dt``.
        :return: tuple of ``TariffDirectionInfo`` sorted by ``departure_dt``.
        """
        # The bounds are compared with departure_dt as plain strings.
        # Assumes departure_dt is an ISO-like string starting with 'YYYY-MM-DD'
        # followed by a time part -- TODO confirm. Under that assumption a
        # departure on min_departure_date passes (the longer string compares
        # greater than its date prefix), while a departure on
        # max_departure_date is excluded.
        min_departure_date_str = min_departure_date.strftime('%Y-%m-%d')
        max_departure_date_str = max_departure_date.strftime('%Y-%m-%d')
        return tuple(
            TariffDirectionInfo(
                arrival_dt=raw_info['arrival_dt'],
                arrival_station_id=raw_info['arrival_station_id'],
                departure_dt=raw_info['departure_dt'],
                departure_station_id=raw_info['departure_station_id'],
                number=raw_info['number'],
                display_number=raw_info.get('display_number'),
                title_dict=raw_info['title_dict'],
                electronic_ticket=bool(raw_info.get('electronic_ticket')),
                # 'places' must be present; an explicit None is passed through.
                places=None if raw_info['places'] is None else tuple(Place(
                    coach_type=p['coach_type'],
                    count=p['count'],
                    # Falls back to the total count when the key is absent.
                    max_seats_in_the_same_car=p.get('max_seats_in_the_same_car', p['count']),
                    price=Price(float(p['price']['value']), p['price']['currency']),
                    price_details=p.get('price_details'),
                    service_class=p.get('service_class'),
                ) for p in raw_info['places']),
                broken_classes=raw_info.get('broken_classes'),
                coach_owners=raw_info.get('coach_owners', []),
                has_dynamic_pricing=bool(raw_info.get('has_dynamic_pricing')),
                two_storey=bool(raw_info.get('two_storey')),
                is_suburban=bool(raw_info.get('is_suburban')),
                first_country_code=raw_info.get('first_country_code'),
                last_country_code=raw_info.get('last_country_code'),
                provider=raw_info.get('provider'),
                raw_train_name=raw_info.get('raw_train_name'),
            ) for raw_info in sorted(
                (
                    info
                    for doc in documents
                    for info in doc[0]
                    if min_departure_date_str < info['departure_dt'] < max_departure_date_str
                ),
                key=operator.itemgetter('departure_dt')
            )
        )

    @traced_function(name='train_wizard_api.lib.pgaas_price_store.tariff_direction_info_provider.TariffDirectionInfoProvider.find_tariffs_by_directions')
    def find_tariffs_by_directions(self, directions, min_departure_date, max_departure_date):
        # type: (List[Tuple[Point, Point]], date, date) -> Tuple[TariffDirectionInfo, ...]
        """Fetch and build tariffs for several directions at once.

        Directions for which an express id cannot be found for either point
        are skipped with a warning.

        :param directions: ``(departure_point, arrival_point)`` pairs.
        :param min_departure_date: first stored departure date to query.
        :param max_departure_date: last stored departure date to query.
        :return: tuple of ``TariffDirectionInfo`` sorted by ``departure_dt``;
            empty when no direction could be resolved.
        """
        self._logger.info('Start search for directions: [%s]', [(d.point_key, a.point_key) for d, a in directions])

        express_id_directions = []
        for departure_point, arrival_point in directions:
            departure_point_key = departure_point.point_key
            arrival_point_key = arrival_point.point_key

            departure_point_express_id = self._express_system_provider.find_express_id(departure_point_key)
            arrival_point_express_id = self._express_system_provider.find_express_id(arrival_point_key)

            if not departure_point_express_id or not arrival_point_express_id:
                # Logger.warning instead of the deprecated Logger.warn alias.
                self._logger.warning(
                    'Skipping prices for [%s-%s], because can not find express codes for one of points [%s-%s]',
                    departure_point_express_id, arrival_point_express_id,
                    departure_point_key, arrival_point_key,
                )
                continue
            express_id_directions.append((departure_point_express_id, arrival_point_express_id))

        if not express_id_directions:
            return ()

        documents = self._tariff_direction_info_source.find_tariffs_by_directions(
            express_id_directions, min_departure_date, max_departure_date
        )
        # The departure_dt filter is widened by one day on each side of the
        # queried date range.
        return self._build_info(documents, min_departure_date - timedelta(1), max_departure_date + timedelta(1))


# Module-level instances wired with the shared storage store, express system
# provider, railway utilities and this module's logger.
tariff_direction_info_source = TariffDirectionInfoSource(storage_store=storage_store)
tariff_direction_info_provider = TariffDirectionInfoProvider(
    tariff_direction_info_source=tariff_direction_info_source,
    express_system_provider=express_system_provider,
    railway_utils=railway_utils,
    logger=getLogger(__name__),
)
