# -*- coding: utf-8 -*-

from __future__ import absolute_import

import cPickle
import logging
from itertools import groupby

from django.db.models.constants import LOOKUP_SEP

from travel.avia.library.python.common.models.geo import Station
from travel.avia.library.python.common.models.schedule import RThreadType, RTStation
from travel.avia.library.python.common.models.transport import TransportType
from travel.avia.library.python.stationschedule.type.suburban import SuburbanSchedule

from travel.avia.admin.precalc.utils import iter_slices, map_groups_forked, replace_water_t_type
from travel.avia.admin.precalc.utils.db import execute_with_conditions, int_time
from travel.avia.admin.precalc.utils.originalbox import OriginalBox


# Logger named 'precalc'; used below for progress messages.
log = logging.getLogger('precalc')

# Field layout handed to OriginalBox for RTStation rows.  precalc_stops()
# iterates the RTStation queryset through this box
# (RTStationBox.iter_queryset_chunked) and reads the listed names back as
# attributes of each item (e.g. ``rts.station_id``, ``rts.thread.route.hidden``).
# NOTE(review): the exact semantics of OriginalBox and of ``'name' | OriginalBox``
# live in precalc.utils.originalbox -- presumably the latter declares a nested
# box for a related object; confirm there.
RTStationBox = OriginalBox(
    'id',
    'tz_arrival',
    'tz_departure',
    'time_zone',
    'station_id',
    'departure_direction_id',
    'arrival_subdir',
    'departure_subdir',
    'is_technical_stop',
    'platform',
    'terminal_id',
    'arrival_t_model_id',
    'departure_t_model_id',
    'is_fuzzy',
    'in_station_schedule',
    # Read back as ``rtstation.schedule__stops`` when stop rows are built.
    'schedule__stops',

    'thread_id',
    # Fields of the related thread, accessed as ``rts.thread.<field>``.
    'thread' | OriginalBox(
        'type_id',
        'tz_start_time',
        'time_zone',
        'year_days',
        'schedule_plan_id',
        'show_in_alldays_pages',

        # Fields of the thread's route, accessed as ``rts.thread.route.<field>``.
        'route' | OriginalBox(
            'hidden',
            't_type_id'
        )
    ),
)


class PrecalcRTStation(object):
    """Thin wrapper that exposes a fetched RTStation box as an object.

    Every attribute that is not found on the wrapper itself is looked up on
    the wrapped ``box``.  The two ``L_*_subdir`` attributes are taken from
    the ``RTStation`` model class so that they are available on the wrapper
    too.
    """

    # NOTE(review): presumably descriptors that resolve against the instance's
    # ``arrival_subdir`` / ``departure_subdir`` fields -- confirm in RTStation.
    L_arrival_subdir = RTStation.L_arrival_subdir
    L_departure_subdir = RTStation.L_departure_subdir

    def __init__(self, box):
        """Remember ``box``, the object all unknown attributes delegate to."""
        self._box = box

    def __getattr__(self, name):
        """Delegate lookups that failed on the wrapper to the wrapped box.

        Raises:
            AttributeError: if ``name`` is ``_box`` (the wrapper has not been
                initialised yet) or the box has no such attribute.
        """
        # __getattr__ only runs when normal lookup fails, so reaching it for
        # ``_box`` means __init__ has not run (e.g. copy/pickle machinery
        # probing a bare instance).  Without this guard ``self._box`` below
        # would re-enter __getattr__ forever.
        if name == '_box':
            raise AttributeError(name)
        return getattr(self._box, name)


class SuburbanPrecalcSchedule(SuburbanSchedule):
    """Departure schedule of a station fed from already fetched rtstations.

    Instead of loading rows itself, the schedule takes the rtstations given
    to the constructor and serves them through ``iter_rtstations_fields``.
    """

    def __init__(self, station, rtstations):
        """Set up a 'departure' schedule for ``station`` and its settlement.

        ``rtstations`` is the iterable of rtstation-like objects that
        ``iter_rtstations_fields`` will filter and yield.
        """
        parent = super(SuburbanPrecalcSchedule, self)
        parent.__init__(station, station.settlement, event='departure')

        self._rtstations = rtstations

    def iter_rtstations_fields(self):
        """Yield ``(rtstation, fields)`` for every rtstation that qualifies.

        ``fields`` maps each name returned by ``self.get_fetching_fields()``
        (not defined in this class) to the value found by following the
        ``LOOKUP_SEP``-separated path from the rtstation; a missing attribute
        anywhere along the path gives ``None``.
        """

        def qualifies(rts):
            # Must be flagged for the station schedule and be a real stop.
            if not rts.in_station_schedule:
                return False
            if rts.is_technical_stop:
                return False

            # Drop stops whose arrival and departure offsets are both known
            # and equal.
            departure = rts.tz_departure
            arrival = rts.tz_arrival
            if departure is not None and arrival is not None and departure == arrival:
                return False

            if rts.thread.route.hidden:
                return False

            skipped_types = (
                RThreadType.THROUGH_TRAIN_ID,
                RThreadType.INTERVAL_ID,
                RThreadType.CANCEL_ID,
            )
            return rts.thread.type_id not in skipped_types

        def follow(start, path):
            # Walk the attribute chain; getattr's default keeps a broken
            # chain at None instead of raising.
            value = start
            for piece in path.split(LOOKUP_SEP):
                value = getattr(value, piece, None)
            return value

        # Filter first, then ask for the field names, as one eager step
        # before anything is yielded.
        selected = [rts for rts in self._rtstations if qualifies(rts)]
        field_names = self.get_fetching_fields()

        for rtstation in selected:
            values = {}
            for field in field_names:
                values[field] = follow(rtstation, field)
            yield rtstation, values


def precalc_stops(connect, precalc_state):
    """Fill the ``stop`` and ``station`` precalc tables from RTStation rows.

    When ``precalc_state.partial`` is truthy only the difference is applied:
    rows of stops that disappeared or belong to changed threads are deleted
    (together with the ``station`` rows of the affected stations) and only
    added/changed stops are written again.  Otherwise both tables are
    dropped, recreated and filled from all RTStation rows.

    Args:
        connect: zero-argument callable returning a new connection to the
            precalc database.  It is called once for the setup phase and
            once per ``process`` call.
        precalc_state: object providing ``partial`` and, in partial mode,
            ``threads_changed`` (thread ids used in a ``thread__id__in``
            filter).

    Returns:
        None.  In partial mode the function returns early when there is
        nothing to add, delete or update.
    """

    def prune_stale_rows(conn, rtstations):
        """Delete outdated rows and narrow ``rtstations`` to what must be rebuilt.

        Returns the queryset limited to added/changed stops, or ``None``
        when the precalc tables are already up to date.
        """
        in_precalc = set(row[0]
                         for row in conn.execute("""SELECT id FROM stop"""))
        in_base = set(rtstations.values_list('id', flat=True))

        stops_deleted = in_precalc - in_base
        stops_added = in_base - in_precalc
        stops_changed = set(rtstations.filter(
            thread__id__in=precalc_state.threads_changed
        ).values_list('id', flat=True))

        log.info(
            '%s stops deleted, %s stops added, %s stops changed',
            len(stops_deleted), len(stops_added), len(stops_changed)
        )

        if not (stops_deleted or stops_added or stops_changed):
            log.info('Nothing to precalc')
            return None

        # Rows that must leave the ``stop`` table; used both to find the
        # affected stations and for the delete itself.
        stale_stops = stops_deleted | stops_changed

        # Stations touched by stale stops (looked up in the precalc table
        # before those rows are removed) plus stations of newly added stops.
        stations_changed = set(
            station_id
            for stop_ids in iter_slices(stale_stops, 500)
            for (station_id,) in execute_with_conditions(conn, """
                SELECT DISTINCT station_id FROM stop WHERE {}
            """, {
                'id IN ?': stop_ids
            })
        ).union(rtstations.filter(id__in=stops_added)
                          .values_list('station_id', flat=True))

        log.info('Deleting stale stops and stations...')

        for stop_ids in iter_slices(stale_stops, 500):
            execute_with_conditions(conn, """
                DELETE FROM stop WHERE {}
            """, {
                'id IN ?': stop_ids
            })

        for station_ids in iter_slices(stations_changed, 500):
            execute_with_conditions(conn, """
                DELETE FROM station WHERE {}
            """, {
                'id IN ?': station_ids
            })

        return rtstations.filter(id__in=(stops_added | stops_changed))

    def recreate_tables(conn):
        """Drop and recreate the ``stop`` and ``station`` tables with indexes."""
        conn.executescript("""
            DROP TABLE IF EXISTS stop;

            CREATE TABLE stop (
                id INTEGER PRIMARY KEY,
                station_id INTEGER NOT NULL,
                t_type_id INTEGER NOT NULL,
                thread_id INTEGER NOT NULL,
                platform TEXT,
                terminal_id INTEGER,
                direction TEXT,
                stops TEXT,
                arrival REAL,
                departure REAL,
                arr_day_shift INTEGER,
                arr_time INTEGER,
                dep_day_shift INTEGER,
                dep_time INTEGER,
                suburban_day_shift INTEGER,
                suburban_time INTEGER,
                time_zone TEXT,
                arr_t_model_id INTEGER,
                dep_t_model_id INTEGER
            );

            CREATE INDEX stop_station_id_arr_time_thread_id ON stop(station_id, t_type_id, arr_time, thread_id);
            CREATE INDEX stop_station_id_dep_time_thread_id ON stop(station_id, t_type_id, dep_time, thread_id);
            CREATE INDEX stop_station_id_t_type_id_direction_suburban_time_thread_id ON stop(station_id, t_type_id, direction, suburban_time, thread_id);
            CREATE INDEX stop_station_id_t_type_id_suburban_time_thread_id ON stop(station_id, t_type_id, suburban_time, thread_id);
            CREATE INDEX stop_thread_id ON stop(thread_id);

            DROP TABLE IF EXISTS station;

            CREATE TABLE station (
                id INTEGER PRIMARY KEY,
                data BLOB NOT NULL
            );
        """)

    rtstations = RTStation.objects.all()

    conn = connect()
    # try/finally so the connection is released on the early return and on
    # any failure, not only on the happy path.
    try:
        if precalc_state.partial:
            rtstations = prune_stale_rows(conn, rtstations)
            # Compare with None explicitly: truth-testing a queryset would
            # evaluate it.
            if rtstations is None:
                return
        else:
            recreate_tables(conn)

        conn.commit()
    finally:
        conn.close()

    def stop_row(rtstation, direction_code=None):
        """Build the value tuple for one ``stop`` row, in table column order.

        Arrival/departure are converted to ``(day_shift, seconds within the
        day)`` by adding the offset (multiplied by 60) to the thread start
        time returned by ``int_time`` and splitting at 86400.
        """
        start_time = int_time(rtstation.thread.tz_start_time)

        arr_day_shift = arr_time = None
        dep_day_shift = dep_time = None

        if rtstation.tz_arrival is not None:
            arr_day_shift, arr_time = divmod(start_time + rtstation.tz_arrival * 60, 86400)

        if rtstation.tz_departure is not None:
            dep_day_shift, dep_time = divmod(start_time + rtstation.tz_departure * 60, 86400)

        if dep_time is None:
            # Terminal stop: there is no departure, fall back to the arrival.
            suburban_day_shift = arr_day_shift
            suburban_time = arr_time
        else:
            suburban_day_shift = dep_day_shift
            suburban_time = dep_time

        return (
            rtstation.id,
            rtstation.station_id,
            replace_water_t_type(rtstation.thread.route.t_type_id),
            rtstation.thread_id,
            rtstation.platform,
            rtstation.terminal_id,
            direction_code,
            rtstation.schedule__stops,
            rtstation.tz_arrival,
            rtstation.tz_departure,
            arr_day_shift,
            arr_time,
            dep_day_shift,
            dep_time,
            suburban_day_shift,
            suburban_time,
            rtstation.time_zone,
            rtstation.arrival_t_model_id,
            rtstation.departure_t_model_id,
        )

    def process(rtstations):
        """Compute and store ``station``/``stop`` rows for a batch of rtstations.

        NOTE(review): ``rtstations`` is iterated twice (station lookup, then
        grouping), so it is presumably a list with each station's rows
        adjacent -- confirm against map_groups_forked.
        """
        # A set avoids repeating a station id once per stop in the lookup.
        stations = Station.objects.in_bulk({rts.station_id for rts in rtstations})

        station_info_rows = []
        stop_rows = []

        for station_id, station_rtstations in groupby(rtstations, lambda rts: rts.station_id):
            # groupby hands out a one-shot iterator; the rows are needed twice.
            station_rtstations = list(station_rtstations)

            suburban_rtstations = [
                PrecalcRTStation(rts)
                for rts in station_rtstations
                if rts.thread.route.t_type_id == TransportType.SUBURBAN_ID
            ]
            direction_codes = {}

            if suburban_rtstations:
                # Partial mode sees only added/changed stops here, so the
                # stock schedule is built instead of one fed from this batch.
                schedule = (
                    SuburbanSchedule(stations[station_id], event='departure')
                    if precalc_state.partial else
                    SuburbanPrecalcSchedule(
                        stations[station_id], suburban_rtstations
                    )
                ).build()

                station_info = {
                    'default_direction': schedule.direction_code,
                    'sorted_directions':
                        schedule.direction_code_title_count_list
                }

                # Pickled and wrapped in buffer() for the ``data BLOB`` column.
                station_info_rows.append((
                    station_id,
                    buffer(cPickle.dumps(station_info, cPickle.HIGHEST_PROTOCOL)),
                ))

                direction_codes.update({
                    schedule_item.rtstation.id: schedule_item.direction_code
                    for schedule_item in schedule.schedule_items
                })

            # Stops absent from the schedule get direction None.
            stop_rows.extend(
                stop_row(rts, direction_codes.get(rts.id))
                for rts in station_rtstations
            )

        conn = connect()
        # Close the connection even if an insert or the commit fails.
        try:
            conn.executemany("""INSERT INTO station VALUES (?, ?)""", station_info_rows)

            conn.executemany("""INSERT INTO stop VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", stop_rows)

            conn.commit()
        finally:
            conn.close()

    # Ordered by station, presumably so that one station's rows arrive
    # together for the grouping in process().
    rtstations = RTStationBox.iter_queryset_chunked(rtstations.order_by('station'))

    map_groups_forked(process, rtstations, len(rtstations),
                      lambda rts: rts.station_id)
