# -*- coding: utf-8 -*-

from __future__ import absolute_import

import cPickle
import logging
import operator
from collections import defaultdict
from itertools import chain, groupby, product

from travel.avia.library.python.common.models.schedule import DeLuxeTrain, RThreadType
from travel.avia.library.python.common.models.tariffs import AeroexTariff, ThreadTariff
from travel.avia.library.python.common.models.transport import TransportType
from travel.avia.library.python.common.utils.date import RunMask
from travel.avia.library.python.route_search.models import ZNodeRoute2

from travel.avia.admin.precalc.utils import (
    iter_slices, map_groups_forked, normalize_string, replace_water_t_type
)
from travel.avia.admin.precalc.utils.db import execute_with_conditions, int_time, pack_run_mask
from travel.avia.admin.precalc.utils.originalbox import OriginalBox
from travel.avia.admin.precalc.utils.route_number import flatten_number_key, generate_number_keys


log = logging.getLogger('precalc')


def prefetch_suburban_tariffs():
    """Load suburban (AeroexTariff) prices keyed by station pair.

    Only tariffs whose type has category ``'usual'`` are used.  Tariffs marked
    ``reverse`` are also registered for the opposite direction.

    :returns: ``{(station_from_id, station_to_id): {type_code: (tariff, currency_code)}}``
        where ``currency_code`` is the departure country's currency code, or a
        falsy value when the country or its currency is missing.
    """
    log.info('Fetching suburban tariffs...')

    result = {}

    # Sorted by -precalc so that manual tariffs come last and overwrite the
    # precalculated ones for the same direction and type.
    queryset = AeroexTariff.objects.prefetch_related(
        'type', 'station_from__country__currency'
    ).order_by('-precalc')

    for tariff_obj in queryset:
        tariff_type = tariff_obj.type
        if tariff_type.category != 'usual':
            continue

        country = tariff_obj.station_from.country
        currency_code = country and country.currency and country.currency.code
        price = (tariff_obj.tariff, currency_code)

        directions = [(tariff_obj.station_from_id, tariff_obj.station_to_id)]
        if tariff_obj.reverse:
            directions.append((tariff_obj.station_to_id, tariff_obj.station_from_id))

        for direction in directions:
            result.setdefault(direction, {})[tariff_type.code] = price

    return result


# Row view over ThreadTariff used by prefetch_thread_tariffs(): the objects it
# yields expose these fields as attributes.  NOTE(review): presumably only
# these columns are fetched from the DB -- confirm against OriginalBox.
ThreadTariffBox = OriginalBox(
    'station_from_id', 'station_to_id', 'thread_uid',
    'year_days', 'tariff', 'currency'
)


def prefetch_thread_tariffs():
    """Load per-thread tariffs grouped by (station_from_id, station_to_id, thread_uid).

    :returns: ``{(station_from_id, station_to_id, thread_uid): [(mask, (tariff, currency)), ...]}``
        where ``mask`` is ``pack_run_mask(year_days)``, or ``None`` when the
        tariff has no ``year_days`` (presumably "valid on any day" -- confirm).
    """
    log.info('Fetching thread tariffs...')

    result = {}

    for row in ThreadTariffBox.iter_queryset(ThreadTariff.objects.all()):
        packed_mask = pack_run_mask(row.year_days) if row.year_days else None
        key = (row.station_from_id, row.station_to_id, row.thread_uid)
        result.setdefault(key, []).append((packed_mask, (row.tariff, row.currency)))

    return result


# OriginalBox field specs for the ZNodeRoute2 records processed by
# precalc_search().  Plain names become attributes of the produced row objects
# (double-underscore lookups such as 'route__hidden' stay flat attribute
# names), and `'field' | Box` exposes a related object limited to Box's
# fields.  NOTE(review): presumably the box also restricts which DB columns
# are fetched -- confirm against travel.avia.admin.precalc.utils.originalbox.

# Station fields: majority_id ranks duplicate segments in remove_duplicates().
StationBox = OriginalBox('majority_id')

# Route-station fields: tz_departure / tz_arrival are the stop offsets used by
# Segment (multiplied by 60 there, so presumably minutes -- confirm);
# time_zone is stored as the departure time zone in the search row data.
RTStationBox = OriginalBox('tz_arrival', 'tz_departure', 'time_zone')

ZNodeRouteBox = OriginalBox(
    'id',

    # Records of hidden routes are skipped when building search rows.
    'route__hidden',

    # Thread fields read by Segment, search_rows() and the extra row data.
    'thread' | OriginalBox(
        'uid', 'tz_start_time', 'type_id', 'year_days', 't_type_id', 'number',
        'tariff_type__code', 'express_type'
    ),
    'thread_id',

    # Departure side: city, station and route-station of the segment.
    'settlement_from_id',
    'station_from' | StationBox,
    'station_from_id',
    'rtstation_from' | RTStationBox,
    'rtstation_from_id',

    # Arrival side, mirroring the departure fields.
    'settlement_to_id',
    'station_to' | StationBox,
    'station_to_id',
    'rtstation_to' | RTStationBox,
    'rtstation_to_id',
)


class Segment(object):
    """Time-of-day view of one ZNodeRoute2 record.

    Attributes:
        dep_time: departure time modulo one day (86400; presumably seconds
            since midnight, assuming int_time returns seconds -- confirm).
        mask_shift: number of whole days between the thread start and the
            departure from the first stop of the segment.
        arr_time: arrival time modulo one day; the arrival day offset is
            intentionally not kept.
        record: the wrapped record.
    """

    def __init__(self, record):
        seconds_per_day = 86400
        thread_start = int_time(record.thread.tz_start_time)
        departure = thread_start + record.rtstation_from.tz_departure * 60
        arrival = thread_start + record.rtstation_to.tz_arrival * 60

        self.mask_shift, self.dep_time = divmod(departure, seconds_per_day)
        self.arr_time = arrival % seconds_per_day
        self.record = record

# One-letter search-key prefix -> point kind as used in ZNodeRoute2 field
# names ('{kind}_from_id' / '{kind}_to_id'), e.g. key 'c213-s9600' means
# settlement 213 -> station 9600.
POINT_TYPES = dict(
    c='settlement',
    s='station',
)


def search_rows(key, segment, data):
    """Yield ``search`` table rows (all columns except ``id``) for one segment.

    Column order matches the ``search`` table: is_extra, key, t_type_id,
    brand, is_express, number_part, dep_time, dep_day_shift, thread_id,
    dep_stop_id, arr_stop_id, data.

    The first row (``is_extra=False``) carries the deluxe-train brand (trains
    only), the express flag and the first number key.  Every further number
    key from ``generate_number_keys`` gets an extra row (``is_extra=True``)
    with brand/is_express left NULL.  All rows share the trailing departure
    and thread columns and the pickled ``data`` blob.

    :param key: search key such as ``'c1-s2'``.
    :param segment: :class:`Segment` to serialize.
    :param data: extra payload, pickled into the ``data`` column.
    """
    record = segment.record
    thread = record.thread
    t_type_id = replace_water_t_type(thread.t_type_id)

    title_special = None
    if thread.t_type_id == TransportType.TRAIN_ID:
        deluxe_train = DeLuxeTrain.get_by_number(thread.number)
        if deluxe_train:
            title_special = deluxe_train.title

    # iter() lets this work whether generate_number_keys returns a generator
    # or a plain sequence; the first key goes into the primary row, the rest
    # into extra rows below.
    number_keys = iter(generate_number_keys(thread.t_type_id, thread.number))
    first_key = next(number_keys, None)

    common_columns = (
        segment.dep_time,
        segment.mask_shift,
        record.thread_id,
        record.rtstation_from_id,
        record.rtstation_to_id,
        # buffer() makes Python 2's sqlite3 store the pickle as a BLOB.
        buffer(cPickle.dumps(data, cPickle.HIGHEST_PROTOCOL)),
    )

    yield (
        False,
        key,
        t_type_id,
        normalize_string(title_special),
        bool(thread.express_type),
        flatten_number_key(first_key)
    ) + common_columns

    for number_key in number_keys:
        yield (
            True,
            key,
            t_type_id,
            None,
            None,
            flatten_number_key(number_key)
        ) + common_columns


def remove_duplicates(segments):
    """Keep a single segment per thread.

    Used for city-level search keys, where one thread can match the same key
    through several of its records.  For each ``record.thread_id`` the
    segment with the smallest preference tuple wins:

    * ``majority_id`` of the departure station, then of the arrival station;
    * departure stops without an arrival time (``tz_arrival is None``) first;
    * latest departure first (``-tz_departure``, a missing value counts as 0);
    * earliest arrival (``tz_arrival`` of the arrival stop).

    Ties keep the segment that appeared first in ``segments``.

    :param segments: iterable of :class:`Segment`; it is not modified.
    :returns: generator of the chosen segments, ordered by thread id.
    """
    by_thread = operator.attrgetter('record.thread_id')

    def preference(segment):
        record = segment.record
        return (
            record.station_from.majority_id,
            record.station_to.majority_id,
            record.rtstation_from.tz_arrival is not None,
            -(record.rtstation_from.tz_departure or 0),
            record.rtstation_to.tz_arrival,
        )

    # sorted() instead of list.sort(): accepts any iterable and does not
    # reorder the caller's list.  groupby() needs input sorted by the same key.
    ordered = sorted(segments, key=by_thread)

    for _thread_id, thread_segments in groupby(ordered, key=by_thread):
        yield min(thread_segments, key=preference)


def remove_through_trains(segments):
    """Remove through (no-change) train threads whose basic thread is also in the search.

    Segments are grouped by (transport type, departure station, departure
    time, arrival station, arrival time).  In a train group with more than
    one segment, if at least one segment belongs to a basic thread
    (``RThreadType.BASIC_ID``), only the basic segments are kept.  Otherwise
    the group is kept as is.  Non-train groups always pass through unchanged.

    Assumes the threads were imported with the same time zones, so equal
    times really denote the same departure/arrival.

    :param segments: iterable of :class:`Segment`; it is not modified.
    :returns: generator of the kept segments, ordered by the grouping key.
    """
    group_key = operator.attrgetter(
        'record.thread.t_type_id',
        'record.station_from_id', 'dep_time',
        'record.station_to_id', 'arr_time',
    )

    # sorted() instead of list.sort(): accepts any iterable and leaves the
    # caller's list untouched.  groupby() needs input sorted by the same key.
    ordered = sorted(segments, key=group_key)

    for key, group in groupby(ordered, key=group_key):
        t_type_id = key[0]
        group = list(group)

        if t_type_id == TransportType.TRAIN_ID and len(group) > 1:
            basic_segments = [
                segment
                for segment in group
                if segment.record.thread.type_id == RThreadType.BASIC_ID
            ]

            if basic_segments:
                group = basic_segments

        for segment in group:
            yield segment


def precalc_search(connect, precalc_state):
    """Precalculate the ``search`` table of search-key -> segment rows.

    Full mode (``precalc_state.partial`` is false):
    the table is dropped, recreated, filled from every ZNodeRoute2 record and
    then indexed.

    Partial mode:
    only rows affected by ``precalc_state.threads_deleted``,
    ``threads_added`` and ``threads_changed`` are deleted and recomputed.
    The search keys of the added/changed threads' records are stored in
    ``precalc_state.searches_changed``.

    :param connect: zero-argument callable returning a new connection to the
        precalc database (used with sqlite3-style ``executescript`` /
        ``executemany`` / ``commit`` / ``close``).  It is called once here and
        once in every ``collect`` call dispatched through ``map_groups_forked``.
    :param precalc_state: run state object described above.
    """
    conn = connect()
    try:
        _precalc_search(conn, connect, precalc_state)
    finally:
        # Also covers the "Nothing to precalc" early exit and any failure.
        # Both previously returned with the connection still open.
        conn.close()


def _precalc_search(conn, connect, precalc_state):
    """Do the work of :func:`precalc_search` on ``conn`` (closed by the caller)."""
    records = ZNodeRoute2.objects.all()

    # None: process every record.  It is passed to iter_queryset_chunked as
    # object_ids, presumably meaning "no id filter" (confirm in OriginalBox).
    # In partial mode it becomes the set of record ids to recompute.
    record_ids = None

    if precalc_state.partial:
        threads_deleted, threads_added, threads_changed = (
            precalc_state.threads_deleted,
            precalc_state.threads_added,
            precalc_state.threads_changed
        )

        if not (threads_deleted or threads_added or threads_changed):
            log.info('Nothing to precalc')
            return

        log.info('Searching modified records...')

        record_ids = set()
        search_keys = set()

        # Train directions are rebuilt as a whole: their rows are deleted by
        # key, and every train record of the direction is recomputed.  Other
        # transport types are rebuilt per thread.  This is presumably because
        # remove_through_trains() filters train segments against each other
        # across threads.

        # Collect the directions the modified train threads used to belong to.
        train_search_keys = set(
            search_key
            for thread_ids in iter_slices((threads_deleted | threads_changed),
                                          500)
            for (search_key,) in execute_with_conditions(conn, """
                SELECT DISTINCT key FROM search WHERE {}
            """, {
                'thread_id IN ?': thread_ids,
                't_type_id = ?': TransportType.TRAIN_ID
            })
        )

        for record in records.filter(
            thread__id__in=(threads_added | threads_changed)
        ).select_related('thread'):
            record_ids.add(record.id)

            # Every station/settlement combination the record can be found by.
            for key_prefixes in product(['s', 'c'], repeat=2):
                point_ids = [
                    getattr(
                        record,
                        '{}_{}_id'.format(POINT_TYPES[key_prefix], direction)
                    )
                    for key_prefix, direction in zip(key_prefixes,
                                                     ('from', 'to'))
                ]

                if not all(point_ids):
                    continue

                search_key = '-'.join(
                    '{}{}'.format(key_prefix, point_id)
                    for key_prefix, point_id in zip(key_prefixes, point_ids)
                )

                # Collect the directions the modified train threads belong to now.
                if record.thread.t_type_id == TransportType.TRAIN_ID:
                    train_search_keys.add(search_key)

                search_keys.add(search_key)

        log.info('Fetching train directions record ids...')

        train_records = records.filter(thread__t_type_id=TransportType.TRAIN_ID)
        record_ids.update(chain.from_iterable(
            train_records.filter(**{
                '{}_{}_id'.format(POINT_TYPES[key[0]], direction): int(key[1:])
                for direction, key in zip(('from', 'to'), search_key.split('-'))
            }).values_list('id', flat=True)
            for search_key in train_search_keys
        ))

        log.info('Deleting stale records...')

        for thread_ids in iter_slices((threads_deleted | threads_changed), 500):
            execute_with_conditions(conn, """
                DELETE FROM search WHERE {}
            """, {
                'thread_id IN ?': thread_ids,
                't_type_id != ?': TransportType.TRAIN_ID
            })

        for search_keys_part in iter_slices(train_search_keys, 500):
            execute_with_conditions(conn, """
                DELETE FROM search WHERE {}
            """, {
                'key IN ?': search_keys_part,
                't_type_id = ?': TransportType.TRAIN_ID
            })

        # NOTE(review): only keys of the added/changed threads' current
        # records end up here.  Directions touched solely by deleted threads,
        # or found only through the old train rows, are not included.  Confirm
        # this matches what consumers of searches_changed expect.
        precalc_state.searches_changed = search_keys

    else:
        conn.executescript("""
            DROP TABLE IF EXISTS search;

            CREATE TABLE search (
                id INTEGER PRIMARY KEY,
                is_extra INTEGER NOT NULL,
                key TEXT NOT NULL,
                t_type_id INTEGER NOT NULL,
                brand TEXT,
                is_express INTEGER,
                number_part TEXT,
                dep_time INTEGER NOT NULL,
                dep_day_shift INTEGER NOT NULL,
                thread_id INTEGER NOT NULL,
                dep_stop_id INTEGER NOT NULL,
                arr_stop_id INTEGER NOT NULL,
                data BLOB NOT NULL
            );
        """)

    conn.commit()

    suburban_tariffs = prefetch_suburban_tariffs()
    thread_tariffs = prefetch_thread_tariffs()

    def extra_data(record):
        """Build the pickled per-row payload.

        ``mask_tariffs`` is a list of ``(packed run mask or None, (tariff, currency))``.
        For suburban threads it starts with the matching AeroexTariff price,
        selected by the thread's tariff type code, or 'express'/'etrain'
        derived from express_type.  Per-thread ThreadTariff prices follow.
        """
        thread = record.thread

        mask_tariffs = []

        if thread.t_type_id == TransportType.SUBURBAN_ID:
            tariffs = suburban_tariffs.get((record.station_from_id, record.station_to_id), {})
            tariff = tariffs.get(
                thread.tariff_type__code or (
                    'express'
                    if thread.express_type in ('express', 'aeroexpress') else
                    'etrain'
                )
            )

            if tariff:
                mask_tariffs.append((None, tariff))

        tariffs = thread_tariffs.get((record.station_from_id, record.station_to_id, thread.uid))

        if tariffs:
            mask_tariffs.extend(tariffs)

        return {
            'mask_tariffs': mask_tariffs,
            'dep_time_zone': record.rtstation_from.time_zone
        }

    def process_bin(records, thread_keys, do_remove_duplicates):
        """Group records into segments per search key and yield ``(key, segments)``.

        Skips cancelled threads, hidden routes and threads with an empty run
        mask (``RunMask.EMPTY_YEAR_DAYS``).  Records whose :class:`Segment`
        cannot be built are logged and skipped.
        """
        noderoute = defaultdict(list)

        for record in records:
            thread = record.thread

            if (thread.type_id == RThreadType.CANCEL_ID or
                    record.route__hidden or
                    thread.year_days == RunMask.EMPTY_YEAR_DAYS):
                continue

            try:
                segment = Segment(record)
            except Exception:
                log.exception("Skipping %r", record.id)
                continue

            for key in thread_keys(record):
                noderoute[key].append(segment)

        for key, segments in noderoute.items():
            if do_remove_duplicates:
                # Duplicates are removed for city-level keys, not station-station.
                segments = list(remove_duplicates(segments))

            segments = remove_through_trains(segments)

            yield key, segments

    def process_records(records, groupby_key, thread_keys, do_remove_duplicates=False):
        """Compute and insert search rows for ``records``, in forked chunks."""
        def collect(records_chunk):
            rows = [
                row
                for key, segments in process_bin(records_chunk, thread_keys, do_remove_duplicates)
                for segment in segments
                for row in search_rows(key, segment, extra_data(segment.record))
            ]

            insert_sql = """
                INSERT INTO search VALUES (NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """
            # Each worker uses its own connection and always releases it.
            worker_conn = connect()
            try:
                worker_conn.executemany(insert_sql, rows)
                worker_conn.commit()
            finally:
                worker_conn.close()

        if record_ids is not None and not record_ids:
            # Partial run with nothing to recompute (e.g. only non-train
            # threads were deleted).  Skip it instead of relying on how
            # iter_queryset_chunked treats an empty id set.
            return

        records = ZNodeRouteBox.iter_queryset_chunked(
            records, object_ids=record_ids, chunk_size=10000
        )

        map_groups_forked(collect, records, len(records), groupby_key,
                          chunksize=5000)

    log.info('Processing city-city, city-station, station-city... (1/2)')

    def city_thread_keys(record):
        """Yield the city-level search keys a record belongs to."""
        if (record.settlement_from_id and record.settlement_to_id and
                record.settlement_from_id != record.settlement_to_id):
            yield 'c{0.settlement_from_id}-c{0.settlement_to_id}'.format(record)

        if record.settlement_from_id:
            yield 'c{0.settlement_from_id}-s{0.station_to_id}'.format(record)

        if record.settlement_to_id:
            yield 's{0.station_from_id}-c{0.settlement_to_id}'.format(record)

    process_records(
        records.exclude(settlement_from__id=None, settlement_to__id=None)
               .order_by('settlement_from__id', 'settlement_to__id'),
        lambda r: (r.settlement_from_id, r.settlement_to_id),
        city_thread_keys,
        do_remove_duplicates=True
    )

    log.info('Processing station-station... (2/2)')

    process_records(
        records.order_by('station_from__id', 'station_to__id'),
        lambda r: (r.station_from_id, r.station_to_id),
        lambda r: ['s{0.station_from_id}-s{0.station_to_id}'.format(r)]
    )

    if not precalc_state.partial:
        log.info('Building indexes...')

        conn.executescript("""
            CREATE INDEX search_thread_id ON search(
                thread_id
            );
            CREATE INDEX search_key_t_type_id_is_extra ON search(
                key, t_type_id, is_extra
            );
            CREATE INDEX search_key_t_type_id_brand ON search(
                key, t_type_id, brand
            );
            CREATE INDEX search_key_t_type_id_is_express ON search(
                key, t_type_id, is_express
            );
            CREATE INDEX search_key_t_type_id_number_part ON search(
                key, t_type_id, number_part
            );
        """)

    conn.commit()
