# -*- coding: utf-8 -*-

from __future__ import absolute_import

import cPickle
import logging
import sqlite3
from bisect import bisect_left, bisect_right
from datetime import datetime, time, timedelta
from itertools import dropwhile, groupby, takewhile

from travel.avia.library.python.common.models.schedule import TrainSchedulePlan
from travel.avia.library.python.common.utils.date import RunMask

from travel.avia.admin.precalc.utils import iter_slices, map_groups_forked
from travel.avia.admin.precalc.utils.db import TmpDB, execute_with_conditions, time_int, unpack_run_mask


log = logging.getLogger('precalc')


class Direction(object):
    """Plain attribute container describing one search direction.

    Instances are created empty; the code that builds them assigns ``key``
    (the search key) and ``records`` (the search rows sharing that key).
    """


# Schema of ``search_totals``. It is shared by the temporary and the main
# database because the final copy uses ``INSERT ... SELECT *`` and therefore
# relies on both tables having identical columns in identical order.
_CREATE_SEARCH_TOTALS_SQL = """
    CREATE TABLE search_totals (
        filter TEXT,
        key TEXT NOT NULL,
        t_type_id INTEGER NOT NULL,
        total INTEGER NOT NULL,
        data BLOB NOT NULL
    );
"""


def _build_day_windows(first_day, days_count=365):
    """Return ``days_count`` consecutive search windows starting at *first_day*.

    Every item is ``(local_date, window_start, window_end)``; the window runs
    from 00:00 of ``local_date`` to 04:00 of the following day.
    """
    windows = []

    for offset in range(days_count):
        local_date = first_day + timedelta(days=offset)

        windows.append((
            local_date,
            datetime.combine(local_date, time(0, 0)),
            datetime.combine(local_date + timedelta(days=1), time(4, 0)),
        ))

    return windows


def _load_thread_run_days(conn, thread_ids, today, schedule_plans):
    """Return ``{thread_id: [run dates]}`` for the given threads.

    Threads that are missing from the ``thread`` table are left out of the
    result. When a thread is bound to a schedule plan, dates before the plan's
    ``start_date`` are skipped and iteration stops at the first date after its
    ``end_date``.

    NOTE(review): the plan trimming assumes ``RunMask.iter_dates()`` yields
    dates in ascending order - confirm against RunMask.
    """
    threads_days = {}

    for thread_id in thread_ids:
        thread_row = conn.execute("""
            SELECT run_mask, schedule_plan_id FROM thread WHERE id = ?
        """, [thread_id]).fetchone()

        if thread_row is None:
            continue

        run_mask, schedule_plan_id = thread_row

        thread_dates = RunMask(
            unpack_run_mask(run_mask), today=today
        ).iter_dates()

        if schedule_plan_id:
            schedule_plan = schedule_plans[schedule_plan_id]
            plan_start = schedule_plan.start_date
            plan_end = schedule_plan.end_date

            thread_dates = takewhile(
                lambda date: date <= plan_end,
                dropwhile(lambda date: date < plan_start, thread_dates)
            )

        # Materialize right away: the lambdas above close over loop locals.
        threads_days[thread_id] = list(thread_dates)

    return threads_days


def _count_departures_per_day(sorted_times, day_windows):
    """Map each local date to the number of departures inside its window.

    *sorted_times* must be sorted ascending. Both window bounds are inclusive
    and dates without any departure are omitted from the result.
    """
    counts = {}

    for local_date, day_start, day_end in day_windows:
        count = (bisect_right(sorted_times, day_end) -
                 bisect_left(sorted_times, day_start))

        if count:
            counts[local_date] = count

    return counts


def _build_total_rows(directions, threads_days, day_windows):
    """Build ``search_totals`` rows for the given directions.

    One row ``(filter, key, t_type_id, total, data)`` is produced per
    direction, transport type and filter, where the filter is ``None`` (all
    records) or ``'is_express'``. ``total`` is the number of contributing
    search records and ``data`` is the pickled per-day departure counts.
    Records whose thread has no run days do not contribute at all.
    """
    rows = []

    for d in directions:
        times = {}
        total_counts = {}

        for r in d.records:
            thread_days = threads_days.get(r['thread_id'])

            if not thread_days:
                continue

            t_type_id = r['t_type_id']

            keys = [(t_type_id, None)]

            if r['is_express']:
                keys.append((t_type_id, 'is_express'))

            for key in keys:
                total_counts[key] = total_counts.get(key, 0) + 1

            days_delta = timedelta(days=r['dep_day_shift'])
            dep_time = time_int(r['dep_time'])

            departures = [
                datetime.combine(day + days_delta, dep_time)
                for day in thread_days
            ]

            for key in keys:
                times.setdefault(key, []).extend(departures)

        for key, key_times in times.items():
            # bisect in _count_departures_per_day needs sorted input.
            key_times.sort()

            t_type_id, f = key

            counts = _count_departures_per_day(key_times, day_windows)

            rows.append((
                f, d.key, t_type_id, total_counts[key],
                buffer(cPickle.dumps(counts, cPickle.HIGHEST_PROTOCOL))
            ))

    return rows


def precalc_search_totals(connect, precalc_state):
    """Recompute the ``search_totals`` table from the ``search`` table.

    For every search key, transport type and filter a row is stored with the
    number of matching search records and a pickled ``{local_date: count}``
    dict of departures for the 365 days starting today.

    In partial mode (``precalc_state.partial``) only the keys listed in
    ``precalc_state.searches_changed`` are deleted and recomputed in place;
    nothing is done when that list is empty. Otherwise all totals are computed
    into a temporary database and then copied into a freshly created
    ``search_totals`` table of the main database.

    All connections opened here are closed even when a step fails.

    :param connect: callable returning a new connection to the main database
    :param precalc_state: object exposing ``partial`` and ``searches_changed``
    """
    if precalc_state.partial and not precalc_state.searches_changed:
        log.info('Nothing to precalc')
        return

    today = datetime.now().date()

    log.info('Precomputing local search windows... (1/3)')

    days = _build_day_windows(today)

    schedule_plans = {
        plan.id: plan for plan in TrainSchedulePlan.objects.all()
    }

    def process(search_rows):
        """Compute and store totals for one batch of key-ordered search rows."""
        directions = []
        thread_ids = set()

        for key, key_rows in groupby(search_rows, lambda row: row['key']):
            d = Direction()

            d.key = key
            d.records = list(key_rows)

            directions.append(d)

            thread_ids.update(r['thread_id'] for r in d.records)

        # The read connection is only needed for the thread lookups; release
        # it before the CPU-bound aggregation starts.
        read_conn = connect()

        try:
            threads_days = _load_thread_run_days(
                read_conn, thread_ids, today, schedule_plans
            )
        finally:
            read_conn.close()

        rows = _build_total_rows(directions, threads_days, days)

        # ``connect_w`` is bound by the enclosing function before any batch
        # is processed.
        write_conn = connect_w()

        try:
            write_conn.executemany(
                """INSERT INTO search_totals VALUES (?, ?, ?, ?, ?)""", rows
            )

            write_conn.commit()
        finally:
            write_conn.close()

    conn = connect()

    try:
        conn.row_factory = sqlite3.Row

        if precalc_state.partial:
            connect_w = connect

            log.info('Deleting stale records...')

            searches_changed = precalc_state.searches_changed

            # Split into chunks of 500 parameters because of an SQLite
            # limitation.
            for search_keys_part in iter_slices(searches_changed, 500):
                execute_with_conditions(conn, """
                    DELETE FROM search_totals WHERE {}
                """, {
                    'key IN ?': search_keys_part
                })

            conn.commit()

        else:
            tmp_db = TmpDB()
            tmp_db.cleanup()

            tmp_conn = tmp_db.connect()

            try:
                tmp_conn.executescript(
                    "DROP TABLE IF EXISTS search_totals;\n" +
                    _CREATE_SEARCH_TOTALS_SQL
                )

                tmp_conn.commit()
            finally:
                tmp_conn.close()

            connect_w = tmp_db.connect

        log.info('Computing totals... (2/3)')

        rows_query = """
            SELECT key, thread_id, dep_time, dep_day_shift, t_type_id, is_express
            FROM search
            WHERE {}
            ORDER BY key
        """

        if precalc_state.partial:
            # Load everything into memory so that the database is not locked
            # while the workers write into it.
            rows = [
                row
                for search_keys_part in iter_slices(searches_changed, 500)
                for row in execute_with_conditions(conn, rows_query, {
                    'NOT is_extra': None, 'key IN ?': search_keys_part})
            ]
            count = len(rows)
        else:
            rows = execute_with_conditions(conn, rows_query,
                                           {'NOT is_extra': None})
            count = conn.execute("""
                SELECT count(*) FROM search WHERE NOT is_extra
            """).fetchone()[0]

        map_groups_forked(process, rows, count, lambda row: row['key'])

        if not precalc_state.partial:
            log.info("Copying totals... (3/3)")

            conn.execute(
                """ATTACH DATABASE ? AS scheduletmp""", [tmp_db.db_path]
            )

            conn.executescript(
                "DROP TABLE IF EXISTS main.search_totals;\n" +
                _CREATE_SEARCH_TOTALS_SQL +
                """
                INSERT INTO search_totals SELECT * FROM scheduletmp.search_totals;

                CREATE INDEX search_totals_key_filter_t_type_id ON search_totals(key, filter, t_type_id);
                """
            )

            conn.commit()
    finally:
        conn.close()

    if not precalc_state.partial:
        # Reached only on success; a failed run leaves the temporary database
        # in place until the next full run cleans it up.
        tmp_db.cleanup()
