# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import json
import logging
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, date, timedelta
from functools import partial
from itertools import chain
from typing import AnyStr, Dict, Iterable, List, Optional, Tuple

from travel.rasp.library.python.api_clients.krasnodar_bus import KrasnodarBusClient
from travel.rasp.library.python.ydb import ensure_table_exists, session_pool_context
from travel.rasp.rasp_scripts.scripts.schedule.krasnodar.availiable_avs import AVAILIABLE_AVS, AVAILIABLE_AVS_DICT
from travel.rasp.rasp_scripts.scripts.schedule.krasnodar.krasnodar_bus_parser import KrasnodarBusParser
from travel.rasp.rasp_scripts.scripts.schedule.krasnodar.krasnodar_cache import KrasnodarCache
from travel.rasp.rasp_scripts.scripts.schedule.krasnodar.models import Destination, Passage
from travel.rasp.rasp_scripts.scripts.schedule.krasnodar.storages import DestinationStorage, PassageStorage
from travel.rasp.rasp_scripts.scripts.schedule.krasnodar.utils import get_hash_key, expire_at, get_cache_key

from common.data_api.ydb.instance import ydb_driver_config
from common.db.mds.clients import mds_s3_common_client
from common.cysix.builder import (
    CarrierBlock, ChannelBlock, GroupBlock, ScheduleBlock, StationBlock, StoppointBlock, ThreadBlock, VehicleBlock
)
from rasp_vault.api import get_secret

from travel.rasp.library.python.common23.date.environment import today


# Default base URL of the Krasnodar bus API; run() uses it when no 'host' kwarg is given.
KRASNODAR_BUS_API_HOST = 'http://83.239.56.140:8080/'
# Passed as ``timeout`` to KrasnodarBusClient (presumably seconds -- TODO confirm against the client).
KRASNODAR_BUS_TIMEOUT = 3
# Default MDS path of the generated schedule XML; run() uses it when no 'mds_path' kwarg is given.
MDS_PATH = 'schedule/bus/krasnodar.xml'

# getLogger() without a name returns the root logger.
log = logging.getLogger()


class KrasnodarBusBuilder(object):
    """Builds a cysix ``ChannelBlock`` schedule from collected Krasnodar bus data.

    Station, vehicle, carrier and thread blocks are memoized on the instance, so every
    entity is constructed only once per builder.
    """

    def __init__(self, destination_storage, passage_storage):
        # type: (DestinationStorage, PassageStorage) -> None
        """Remember the data storages and start with empty block caches."""
        self._destination_storage = destination_storage
        self._passage_storage = passage_storage
        self._stations = {}  # destination id -> StationBlock
        self._vehicles = {}  # vehicle code -> VehicleBlock
        self._carriers = {}  # carrier code -> CarrierBlock
        self._threads = {}  # passage id -> ThreadBlock

    def build_schedule(self, dt, days_count):
        # type: (datetime, int) -> ChannelBlock
        """Build the channel with a single group that holds the threads of every available bus station.

        :param dt: first day of the schedule.
        :param days_count: number of consecutive days, starting at ``dt``, to include.
        :return: the populated channel block.
        """
        channel_block = ChannelBlock(
            t_type='bus',
            station_code_system='vendor',
            carrier_code_system='vendor'
        )
        group_block = GroupBlock(channel_block, title='Расписание Краснодарского автовокзала', code='1')
        channel_block.add_group_block(group_block)

        for av in AVAILIABLE_AVS:
            self.build_av(group_block, dt, av, days_count)

        return channel_block

    def build_av(self, group_block, now, kodp, days_count):
        # type: (GroupBlock, date, Destination, int) -> None
        """Add to ``group_block`` every thread that starts at the bus station ``kodp``.

        :param group_block: group that receives the thread blocks.
        :param now: first day to process.
        :param kodp: departure bus station.
        :param days_count: number of consecutive days, starting at ``now``, to process.
        """
        destinations = self._destination_storage.get_response(kodp)
        if not destinations:
            return

        destinations_by_id = {destination.id: destination for destination in destinations}
        builded_threads = {}

        for days_shift in range(days_count):
            dt = now + timedelta(days=days_shift)

            # Group that day's passages from all destinations by passage id:
            # one id corresponds to one thread.
            passages_by_id = defaultdict(list)
            for destination in destinations:
                passages = self._passage_storage.get_response(kodp, destination, dt)
                if passages is None:
                    continue

                for passage in passages:
                    passages_by_id[passage.id].append(passage)

            for passage_id, thread_passages in passages_by_id.items():
                builded_thread = self.build_thread(
                    group_block, passage_id, thread_passages, kodp, destinations_by_id
                )
                if builded_thread is not None:
                    builded_threads[builded_thread.number] = builded_thread

        for thread_block in builded_threads.values():
            group_block.add_thread_block(thread_block)

    def build_station(self, destination, group_block):
        # type: (Destination, GroupBlock) -> StationBlock
        """Return the station block for ``destination``, creating and caching it on first use."""
        station_block = self._stations.get(destination.id)
        if station_block is not None:
            return station_block

        station_block = StationBlock(
            group=group_block,
            title=destination.name,
            code=destination.id
        )
        station_block.country_code = 'RU'
        self._stations[destination.id] = station_block
        return station_block

    def get_middle_segment(self, passage_id, dt, kodp, kpp):
        # type: (AnyStr, datetime, Destination, Destination) -> Optional[Passage]
        """Find the passage ``passage_id`` among the stored passages from ``kodp`` to ``kpp`` on ``dt``.

        :return: the matching passage, or ``None`` when nothing is stored for the
            segment or no stored passage has that id.
        """
        segment_reises = self._passage_storage.get_response(kodp, kpp, dt)
        if segment_reises is None:
            return None

        for segment_reis in segment_reises:
            if segment_reis.id == passage_id:
                return segment_reis
        return None

    def build_thread(self, group_block, passage_id, passages, kodp, destinations_by_id):
        # type: (GroupBlock, AnyStr, List[Passage], Destination, Dict) -> Optional[ThreadBlock]
        """Build (or extend with one more day) the thread block for ``passage_id``.

        :param group_block: group that receives the station, vehicle and carrier blocks.
        :param passage_id: id shared by all ``passages``.
        :param passages: passages of one thread for one day, one per destination.
        :param kodp: departure bus station.
        :param destinations_by_id: destinations of ``kodp`` keyed by their id.
        :return: the thread block, or ``None`` when the thread is skipped: its name does
            not start with ``'<kodp.name>-'`` or none of its passages departs from ``kodp``.
        """
        thread_name = passages[0].name
        thread_number = passages[0].id

        if not thread_name.startswith(kodp.name + '-'):
            return None

        reises_from_first_station = [passage for passage in passages if passage.search_dest_from == kodp.id]
        if not reises_from_first_station:
            # Without a passage departing from kodp there is no departure time to build
            # the thread from; skip it instead of failing the whole build on IndexError.
            logging.getLogger().warning(
                'Thread %s has no passages departing from %s, skipping', passage_id, kodp.id
            )
            return None

        # Stops follow in the order of arrival.
        reises_from_first_station.sort(key=lambda x: x.arrival)
        thread_departure = reises_from_first_station[0].departure

        start_station = self.build_station(kodp, group_block)
        group_block.add_station_block(start_station)

        vehicle_block = self.build_vehicle(passages[0], group_block)
        group_block.add_vehicle_block(vehicle_block)

        carrier_block = self.build_carrier(passages[0], group_block)
        group_block.add_carrier_block(carrier_block)

        thread_block = self._threads.get(passage_id)

        if thread_block is None:
            thread_block = ThreadBlock(
                group=group_block,
                title=thread_name,
                number=thread_number,
                vehicle=vehicle_block,
                carrier=carrier_block
            )

            stoppoints = [
                StoppointBlock(
                    thread=thread_block,
                    station=start_station,
                    departure_time=thread_departure.strftime('%H:%M'),
                )
            ]

            for reis in reises_from_first_station:
                arrival = reis.arrival
                # Default departure from a stop is one minute after the arrival.
                departure = reis.arrival + timedelta(minutes=1)
                station = destinations_by_id[reis.search_dest_to]
                station_block = self.build_station(station, group_block)
                group_block.add_station_block(station_block)
                stoppoints.append(StoppointBlock(
                    thread=thread_block, station=station_block,
                    arrival_time=arrival.strftime('%H:%M'), departure_time=departure.strftime('%H:%M')
                ))

            # The final stop has no departure.
            stoppoints[-1].departure_time = None
            for stoppoint in stoppoints:
                thread_block.add_stoppoint_block(stoppoint)

            # When an intermediate stop is itself an available bus station, take the
            # departure time from the passage with the same id that goes from that
            # station to the last stop, if such a passage is stored.
            last_destination = destinations_by_id[reises_from_first_station[-1].search_dest_to]
            for i, reis in enumerate(reises_from_first_station[:-1]):
                if reis.search_dest_to in AVAILIABLE_AVS_DICT:
                    first_destination = AVAILIABLE_AVS_DICT[reis.search_dest_to]
                    segment = self.get_middle_segment(
                        passage_id, thread_departure.date(), first_destination, last_destination
                    )
                    if segment is not None:
                        # stoppoints[0] is the start station, hence the i + 1 shift.
                        stoppoints[i + 1].departure_time = segment.departure.strftime('%H:%M')

            self._threads[passage_id] = thread_block

        # Register the day: append it to the existing schedule or create the first one.
        schedule_str = thread_departure.strftime('%Y-%m-%d')
        if thread_block._schedules:
            schedule_block = thread_block._schedules[0]
            schedule_block.days += ';{}'.format(schedule_str)
        else:
            schedule_block = ScheduleBlock(thread_block, schedule_str)
            thread_block.add_schedule_block(schedule_block)

        return thread_block

    def build_vehicle(self, reis, group_block):
        # type: (Passage, GroupBlock) -> VehicleBlock
        """Return the vehicle block for ``reis.busmark``, creating and caching it on first use."""
        vehicle_title = reis.busmark
        vehicle_code = get_hash_key(vehicle_title)
        vehicle_block = self._vehicles.get(vehicle_code)

        if vehicle_block is None:
            vehicle_block = VehicleBlock(
                group=group_block,
                title=vehicle_title,
                code=vehicle_code
            )
            self._vehicles[vehicle_code] = vehicle_block
        return vehicle_block

    def build_carrier(self, reis, group_block):
        # type: (Passage, GroupBlock) -> CarrierBlock
        """Return the carrier block for ``reis.atp``, creating and caching it on first use."""
        carrier_title = reis.atp
        carrier_code = get_hash_key(carrier_title)
        carrier_block = self._carriers.get(carrier_code)

        if carrier_block is None:
            carrier_block = CarrierBlock(
                group=group_block,
                title=carrier_title,
                code=carrier_code
            )
            # Store under the same key that is used for the lookup above.
            self._carriers[carrier_code] = carrier_block
        return carrier_block


def write_result(xml, mds_path, is_local):
    """Store the generated XML.

    :param xml: schedule XML as text.
    :param mds_path: path passed to ``mds_s3_common_client.save_data`` for non-local runs.
    :param is_local: when truthy, write UTF-8 encoded XML to ``res.xml`` in the
        current directory instead of MDS.
    """
    if not is_local:
        mds_s3_common_client.save_data(mds_path, xml)
        return

    with open('res.xml', 'wb') as out:
        out.write(xml.encode('utf-8'))


def collect_av_destinations(args):
    # type: (Tuple[KrasnodarBusParser, Destination]) -> Tuple[Destination, List[Destination]]
    """Fetch the destinations of one bus station.

    Takes a single ``(parser, av)`` pair so that it can be used with ``Executor.map``.

    :return: ``(av, destinations)`` where ``destinations`` is ``parser.get_destinations(av)``.
    """
    parser, av = args
    return av, parser.get_destinations(av)


def collect_av_passages(parser, args):
    # type: (KrasnodarBusParser, Tuple[Destination, Destination, datetime]) -> Tuple[Destination, Destination, datetime, List[Passage]]
    """Fetch the passages of one ``(av, destination, dt)`` triple through ``parser``.

    :return: the triple extended with ``parser.get_passages(dt, av, destination)``.
    """
    av, destination, dt = args
    return av, destination, dt, parser.get_passages(dt, av, destination)


def collect_av_passages_from_cache(cache, args):
    # type: (KrasnodarCache, Tuple[Destination, Destination, datetime]) -> Tuple[Destination, Destination, datetime, List[Passage]]
    """Read the passages of one ``(av, destination, dt)`` triple from ``cache``.

    :return: the triple extended with the passages decoded from the cached JSON,
        or with an empty list when the cache has no (truthy) row for the key.
    """
    av, destination, dt = args

    cached = cache.get(get_cache_key(dt, av, destination))
    passages = []
    if cached:
        # The cached value is a JSON list of Passage keyword-argument dicts.
        passages = [Passage(**fields) for fields in json.loads(cached.value)]

    return av, destination, dt, passages


def add_to_cache(cache, av, destination, dt, passages):
    # type: (KrasnodarCache, Destination, Destination, datetime, List[Passage]) -> None
    """Store ``passages`` in ``cache`` as one JSON row keyed by ``(dt, av, destination)``."""
    key = get_cache_key(dt, av, destination)
    value = json.dumps([passage.to_dict() for passage in passages])
    record = {
        'key': key,
        'value': value,
        'expire_at': expire_at(dt)
    }
    cache.add([record])


def build_collect_av_passages_args(destination_storage, now, days):
    # type: (DestinationStorage, datetime, List[int]) -> Iterable[Tuple[Destination, Destination, datetime]]
    """Yield ``(av, destination, dt)`` for every available bus station, each of its stored destinations and each day shift.

    Bus stations for which ``destination_storage`` returns ``None`` are skipped.

    :param days: day offsets added to ``now``.
    """
    for av in AVAILIABLE_AVS:
        destinations = destination_storage.get_response(av)
        if destinations is not None:
            for destination in destinations:
                for shift in days:
                    yield av, destination, now + timedelta(days=shift)


def collect_data(parser, cache, now, days_count, days_shift_from, days_shift_to, pool_size, is_local):
    # type: (KrasnodarBusParser, KrasnodarCache, date, int, int, int, int, bool) -> Tuple[DestinationStorage, PassageStorage]
    """Fill the destination and passage storages.

    For a local run both storages are restored with their ``load()`` methods.
    Otherwise destinations are fetched for every available bus station, passages
    for the "load" days are fetched through ``parser`` (and written to ``cache``),
    passages for the remaining days are read from ``cache``, and both storages
    are dumped.

    The "load" days are the shifts in ``range(days_shift_from, days_shift_to)``
    plus the last day ``days_count - 1``; every other shift in
    ``range(days_count)`` is served from the cache.

    :return: ``(destination_storage, passage_storage)``.
    """
    destination_storage = DestinationStorage({}, {})
    passage_storage = PassageStorage({})

    load_days = list(set(chain(range(days_shift_from, days_shift_to), [days_count - 1])))
    cache_days = list(set(range(days_count)) - set(load_days))

    if is_local:
        destination_storage.load()
        passage_storage.load()
        return destination_storage, passage_storage

    # Destinations must be stored first: the passage arguments are generated from them.
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        destination_jobs = [(parser, av) for av in AVAILIABLE_AVS]
        for av, destinations in executor.map(collect_av_destinations, destination_jobs):
            destination_storage.add_response(av, destinations)

    fetch_passages = partial(collect_av_passages, parser)
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        fetch_args = build_collect_av_passages_args(destination_storage, now, load_days)
        for av, destination, dt, passages in executor.map(fetch_passages, fetch_args):
            passage_storage.add_response(av, destination, dt, passages)
            add_to_cache(cache, av, destination, dt, passages)

    read_cached = partial(collect_av_passages_from_cache, cache)
    cached_args = build_collect_av_passages_args(destination_storage, now, cache_days)
    for av, destination, dt, passages in map(read_cached, cached_args):
        passage_storage.add_response(av, destination, dt, passages)

    destination_storage.dump()
    passage_storage.dump()

    return destination_storage, passage_storage


def run_with_params(
    krasnodar_login,
    krasnodar_password,
    krasnodar_host,
    days_count,
    days_shift_from,
    days_shift_to,
    pool_size,
    mds_path,
    is_local,
    is_local_result
):
    """Collect the data, build the schedule starting from today and store the resulting XML.

    The ``krasnodar_*`` arguments configure ``KrasnodarBusClient``; the day and pool
    arguments and ``is_local`` are forwarded to ``collect_data``; ``mds_path`` and
    ``is_local_result`` are forwarded to ``write_result``.
    """
    now = today()

    with session_pool_context(ydb_driver_config) as context:
        cache = KrasnodarCache(context)
        ensure_table_exists(cache)

        parser = KrasnodarBusParser(
            KrasnodarBusClient(
                host=krasnodar_host,
                login=krasnodar_login,
                password=krasnodar_password,
                timeout=KRASNODAR_BUS_TIMEOUT
            ),
            AVAILIABLE_AVS_DICT
        )

        storages = collect_data(
            parser=parser,
            cache=cache,
            now=now,
            days_count=days_count,
            days_shift_from=days_shift_from,
            days_shift_to=days_shift_to,
            pool_size=pool_size,
            is_local=is_local
        )
        destination_storage, passage_storage = storages

        channel = KrasnodarBusBuilder(destination_storage, passage_storage).build_schedule(now, days_count)
        write_result(channel.to_unicode_xml(), mds_path, is_local_result)


def _env_flag(name):
    # type: (str) -> bool
    """Return True when the environment variable ``name`` holds an enabling value.

    ``os.getenv`` returns the raw string, so values such as ``'0'`` or ``'false'``
    would be truthy if used directly. An unset variable, an empty value and
    ``0``/``false``/``no``/``off`` (case-insensitive) mean "disabled"; any other
    value means "enabled".
    """
    value = os.getenv(name, '')
    return value.strip().lower() not in ('', '0', 'false', 'no', 'off')


def run(**kwargs):
    """Script entry point: read the secrets, keyword options and environment flags, then run the build.

    Supported keyword options and their defaults: ``days_count`` (2),
    ``days_shift_from`` (1), ``days_shift_to`` (1), ``pool_size`` (4),
    ``host`` (``KRASNODAR_BUS_API_HOST``), ``mds_path`` (``MDS_PATH``).

    Environment flags:
      * ``RASP_KRASNODAR_LOCAL_RUN`` -- use locally cached API responses;
      * ``RASP_KRASNODAR_LOCAL_RESULT`` -- save the result locally.
    """
    login = get_secret('rasp-common.RASP_KRASNODAR_BUS_LOGIN')
    password = get_secret('rasp-common.RASP_KRASNODAR_BUS_PASSWORD')

    days_count = kwargs.get('days_count', 2)
    days_shift_from = kwargs.get('days_shift_from', 1)
    days_shift_to = kwargs.get('days_shift_to', 1)
    pool_size = kwargs.get('pool_size', 4)
    host = kwargs.get('host', KRASNODAR_BUS_API_HOST)
    mds_path = kwargs.get('mds_path', MDS_PATH)
    is_local = _env_flag('RASP_KRASNODAR_LOCAL_RUN')  # Local caching of the API responses
    # A local run always stores its result locally as well.
    is_local_result = is_local or _env_flag('RASP_KRASNODAR_LOCAL_RESULT')

    run_with_params(
        krasnodar_login=login,
        krasnodar_password=password,
        krasnodar_host=host,
        days_count=days_count,
        days_shift_from=days_shift_from,
        days_shift_to=days_shift_to,
        pool_size=pool_size,
        mds_path=mds_path,
        is_local=is_local,
        is_local_result=is_local_result
    )
