# -*- coding: utf-8 -*-

from __future__ import absolute_import

import json
import logging
from collections import defaultdict

from django.db.models import Q

from travel.avia.library.python.common.models.schedule import DeLuxeTrain, RThread, RThreadType, RTStation
from travel.avia.library.python.common.models.transport import TransportType

from travel.avia.admin.precalc.utils import (
    iter_slices, map_groups_forked, normalize_string, replace_water_t_type
)
from travel.avia.admin.precalc.utils.db import execute_with_conditions, pack_run_mask
from travel.avia.admin.precalc.utils.originalbox import OriginalBox
from travel.avia.admin.precalc.utils.route_number import generate_number_keys, flatten_number_key

log = logging.getLogger('precalc')

# Upper bound on deleted + added + changed threads for an incremental (partial)
# precalc; above it precalc_threads switches to a complete rebuild.
PARTIAL_PRECALC_THREADS_LIMIT = 100000


# Fields loaded for every thread by precalc_threads (via
# RThreadBox.iter_queryset_chunked). Names such as 'company__title' are read
# back as attributes (thread.company__title) when building rows.
# NOTE(review): the ``'name' | OriginalBox(...)`` entries presumably declare a
# nested box for a related object -- confirm against OriginalBox.
RThreadBox = OriginalBox(
    'id',
    'title',
    'title_common',
    't_model_id',
    'uid',
    'year_days',
    'time_zone',
    'type_id',
    'tz_start_time',
    'is_circular',
    'is_combined',
    'number',
    'company_id',
    'company__title',
    'schedule_plan_id',
    'express_type',
    'route_id',
    'supplier_id',
    'supplier' | OriginalBox('code'),
    't_type_id',
    't_type' | OriginalBox('code')
)

# Field layout for route-station rows. precalc_threads in this section loads
# RTStation model instances directly instead; presumably this box is used
# elsewhere -- verify before removing.
RTStationBox = OriginalBox(
    'thread_id',
    'is_virtual_end',
    'is_combined',

    'station' | OriginalBox(
        'id',
        'settlement_id',
        'not_generalize'
    )
)


class PrecalcRThread(object):
    """Minimal RThread stand-in wrapping a row box from ``RThreadBox``.

    Any attribute not found on the instance itself is read from the wrapped
    box. ``gen_title`` is borrowed from ``RThread`` (``im_func`` is the
    Python 2 unbound-method accessor), so titles can be generated without
    a full model instance. Attributes assigned on the instance (e.g.
    ``rtstations``) shadow the box's values.
    """

    def __init__(self, box):
        self._box = box

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_box`` itself
        # is missing (e.g. an instance created by copy/pickle without
        # __init__), ``self._box`` below would call back into __getattr__
        # and recurse forever -- raise AttributeError instead.
        if name == '_box':
            raise AttributeError(name)
        return getattr(self._box, name)

    gen_title = RThread.gen_title.im_func


def restore_title_common(title_common):
    """Reduce a thread's ``title_common`` to the legacy direction-grouping key.

    :param title_common: either a JSON object carrying a ``title_parts`` list,
        a legacy plain (non-JSON) string, or an empty value.
    :return: ``None`` for an empty value or for a JSON object whose
        ``title_parts`` is missing, is not a list, or contains non-string
        parts; the input unchanged when it is not a JSON object (legacy
        format); otherwise the string parts joined with ``'_'``.
    """
    if not title_common:
        return None

    try:
        title_dict = json.loads(title_common)
    except ValueError:
        # Not JSON at all: already in the legacy plain format.
        return title_common

    if not isinstance(title_dict, dict):
        # A plain title that happens to be valid JSON (e.g. "123" or "null");
        # indexing it would raise TypeError, so treat it like the legacy format.
        return title_common

    parts = title_dict.get('title_parts')
    if not isinstance(parts, list):
        # Missing/odd structure: cannot derive a key.
        return None

    # Do not parse exotic titles (parts that are not plain strings).
    if not all(isinstance(part, basestring) for part in parts):
        return None

    return '_'.join(parts)


def precalc_threads(connect, precalc_state):
    """Fill the ``thread`` and ``thread_number_key`` tables from ``RThread``.

    Only threads with no type or of basic/change/through-train type whose
    route is not hidden are taken; threads with fewer than two route stations
    are skipped.

    Complete mode drops and recreates both tables, fills them and then builds
    indexes. Partial mode compares thread ids in the ``thread`` table with
    the database, deletes rows of deleted/changed threads and re-inserts only
    added/changed ones.

    :param connect: zero-argument callable returning a new DB connection; it
        is called here and again inside every ``process_threads`` call run
        via ``map_groups_forked``.
    :param precalc_state: mutable state. ``partial`` selects incremental mode
        and is reset to False when more than ``PARTIAL_PRECALC_THREADS_LIMIT``
        threads are affected; in partial mode ``threads_deleted``,
        ``threads_added`` and ``threads_changed`` (sets of ids) are stored on it.
    """
    threads = RThread.objects.filter(
        Q(type=None) | Q(type__in=[
            RThreadType.BASIC_ID,
            RThreadType.CHANGE_ID,
            RThreadType.THROUGH_TRAIN_ID
        ]),
        route__hidden=False,
    )

    conn = connect()

    if precalc_state.partial:
        in_precalc = set(row[0] for row in conn.execute("""SELECT id FROM thread"""))
        in_base = set(threads.values_list('id', flat=True))

        threads_deleted = precalc_state.threads_deleted = in_precalc - in_base
        threads_added = precalc_state.threads_added = in_base - in_precalc
        threads_changed = precalc_state.threads_changed = set(
            threads.filter(changed=True).values_list('id', flat=True)
        )

        log.info(
            '%s threads deleted, %s threads added, %s threads changed',
            len(threads_deleted), len(threads_added), len(threads_changed)
        )

        if not (threads_deleted or threads_added or threads_changed):
            log.info('Nothing to precalc')
            # Don't leak the connection opened above on this early exit.
            conn.close()
            return

        if (len(threads_deleted) + len(threads_added) + len(threads_changed) >
                PARTIAL_PRECALC_THREADS_LIMIT):
            log.info('Too many affected threads for partial precalc. Falling back to complete precalc')
            precalc_state.partial = False
        else:
            log.info('Deleting stale threads...')

            # Changed threads are deleted too: they are re-inserted below.
            for thread_ids in iter_slices((threads_deleted | threads_changed),
                                          500):
                execute_with_conditions(conn, """
                    DELETE FROM thread WHERE {}
                """, {
                    'id IN ?': thread_ids
                })

                execute_with_conditions(conn, """
                    DELETE FROM thread_number_key WHERE {}
                """, {
                    'thread_id IN ?': thread_ids
                })

            # Only added and changed threads need to be (re)computed.
            threads = threads.filter(id__in=(threads_added | threads_changed))

    if not precalc_state.partial:
        conn.executescript("""
            DROP TABLE IF EXISTS thread;

            CREATE TABLE thread (
                id INTEGER PRIMARY KEY,
                type_id INTEGER,
                route_id INTEGER NOT NULL,
                t_type_id INTEGER NOT NULL,
                number TEXT NOT NULL,
                title TEXT NOT NULL,
                title_common TEXT NOT NULL,
                direction_key TEXT,
                company_id INTEGER,
                schedule_plan_id INTEGER,
                brand TEXT,
                brand_key TEXT,
                express_type TEXT,
                t_model_id INTEGER,
                uid TEXT NOT NULL,
                run_mask BLOB NOT NULL,
                time_zone TEXT,
                for_search TEXT NOT NULL,
                first_stop_id INTEGER NOT NULL,
                last_stop_id INTEGER NOT NULL,
                last_station_id INTEGER NOT NULL
            );

            DROP TABLE IF EXISTS thread_number_key;

            CREATE TABLE thread_number_key (
                thread_id INTEGER NOT NULL,
                number_key TEXT NOT NULL
            );
        """)

    conn.commit()
    conn.close()

    def thread_row(thread):
        """Build one ``thread`` table row (column order matches CREATE TABLE)."""
        deluxe_train = DeLuxeTrain.get_by_number(thread.number)
        brand = deluxe_train.title if deluxe_train else ''

        return (
            thread.id,
            thread.type_id,
            thread.route_id,
            replace_water_t_type(thread.t_type_id),
            thread.number,
            thread.title,
            thread.title_common,
            thread.direction_key,
            thread.company_id,
            thread.schedule_plan_id,
            brand,
            normalize_string(brand),
            thread.express_type,
            thread.t_model_id,
            thread.uid,
            buffer(pack_run_mask(thread.year_days)),
            thread.time_zone,
            # company__title may be None (company_id is nullable); without the
            # fallback '%s' would put the word "None" into the search text.
            (u"%s %s|%s|%s" % (
                thread.number, thread.title, thread.company__title or u'', brand
            )).lower(),
            thread.title_rtstations[0].id,
            thread.title_rtstations[-1].id,
            thread.rtstations[-1].station_id,
        )

    def process_threads(threads):
        """Prepare one group of thread boxes and insert their rows."""
        rtstations = RTStation.objects.filter(
            thread__in=[thread.id for thread in threads]
        ).select_related('station').order_by('id')

        rtstations_by_thread = defaultdict(list)

        for rts in rtstations:
            rtstations_by_thread[rts.thread_id].append(rts)

        filtered_threads = []

        for thread in threads:
            thread_rtstations = rtstations_by_thread[thread.id]

            if len(thread_rtstations) < 2:
                # skip threads with a single station or with no stations
                continue

            # compatibility with title_generator
            thread = PrecalcRThread(thread)
            thread.rtstations = thread_rtstations
            thread.gen_title()

            # use title_common in the legacy format as the direction grouping key
            thread.direction_key = restore_title_common(thread.title_common)

            # RASP-10130: for bus threads, build tab names as "city - city"
            if thread.t_type_id == TransportType.BUS_ID:
                station_from, station_to = (
                    thread.title_rtstations[0].station,
                    thread.title_rtstations[-1].station
                )

                if station_from.settlement_id != station_to.settlement_id:
                    # 'c<settlement_id>' when the station has a settlement,
                    # otherwise 's<station id>'.
                    thread.direction_key = '_'.join((
                        'c{0.settlement_id}' if s.settlement_id else 's{0.id}'
                    ).format(s) for s in (station_from, station_to))

            filtered_threads.append(thread)

        conn = connect()

        conn.executemany("""
            INSERT INTO thread VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, map(thread_row, filtered_threads))

        conn.executemany("""
            INSERT INTO thread_number_key VALUES (?, ?)
        """, [
            (t.id, flatten_number_key(number_key))
            for t in filtered_threads
            for number_key in generate_number_keys(t.t_type_id, t.number)
        ])

        conn.commit()
        conn.close()

    threads = RThreadBox.iter_queryset_chunked(threads, chunk_size=10000)

    map_groups_forked(process_threads, threads, len(threads),
                      lambda thread: thread.id, chunksize=5000)

    if not precalc_state.partial:
        # Indexes are created only after a complete refill; in partial mode
        # they already exist from the last complete run.
        conn = connect()

        log.info("Building indexes...")

        conn.executescript("""
            CREATE INDEX thread_brand_key ON thread(brand_key);
            CREATE INDEX thread_number_key_number_key_thread_id ON thread_number_key(number_key, thread_id);
        """)

        conn.close()
