# -*- coding: utf-8 -*-

from __future__ import absolute_import

import logging

from travel.avia.admin.precalc.utils.db import SQLiteSnapshot
from travel.avia.admin.precalc.parts.avia_directions import precalc_avia_directions
from travel.avia.admin.precalc.parts.search import precalc_search
from travel.avia.admin.precalc.parts.search_totals import precalc_search_totals
from travel.avia.admin.precalc.parts.stops import precalc_stops
from travel.avia.admin.precalc.parts.threads import precalc_threads


# Module-level logger, named after this module, used for progress messages.
log = logging.getLogger(__name__)


def precache():
    """Call ``objects.precache()`` on every model listed below, in a fixed order.

    The model modules are imported inside the function rather than at module
    level. NOTE(review): presumably this defers loading the model layer until
    precaching actually runs; confirm before hoisting the imports.
    What ``objects.precache()`` does is defined by the models' manager and is
    not visible here; the call order is kept as-is in case it matters.
    """
    log.info("Precaching...")

    from travel.avia.library.python.common.models.geo import (
        CodeSystem,
        Country,
        Direction,
        DirectionFromTranslate,
        DirectionTranslate,
        ExternalDirection,
        Region,
        Settlement,
        StationMajority,
        StationTerminal,
        StationType,
    )
    from travel.avia.library.python.common.models.schedule import (
        Company,
        PlatformTranslation,
        RThreadType,
        Supplier,
    )
    from travel.avia.library.python.common.models.tariffs import (
        Setting,
        SuburbanTariff,
        TariffType,
    )
    from travel.avia.library.python.common.models.transport import (
        TransportModel,
        TransportType,
    )
    from travel.avia.admin.www.models.schedule import Holiday, PreHoliday

    models_to_precache = (
        StationType,
        StationMajority,
        TransportType,
        TransportModel,
        Country,
        Region,
        Settlement,
        StationTerminal,
        RThreadType,
        Supplier,
        Holiday,
        PreHoliday,
        Direction,
        ExternalDirection,
        SuburbanTariff,
        TariffType,
        Company,
        CodeSystem,
        Setting,
        PlatformTranslation,
        DirectionTranslate,
        DirectionFromTranslate,
    )

    for model_cls in models_to_precache:
        model_cls.objects.precache()


# All known precalc part names. They run in this order when no explicit
# selection is given, and requested part names are validated against it.
PARTS = ('threads', 'stops', 'search', 'search_totals', 'avia_directions')


class PrecalcState(object):
    """State object handed to every precalc part during a single run.

    Carries the ``partial`` flag the run was started with. The counters below
    default to ``None`` at class level. NOTE(review): presumably the individual
    parts fill them in for later parts to read; confirm in the part modules.
    """

    # Class-level defaults; all start out as None.
    threads_added = None
    threads_deleted = None
    threads_changed = None
    searches_changed = None

    def __init__(self, partial):
        # Stored exactly as given by the caller (not coerced to bool).
        self.partial = partial


def precalc(parts=None, partial=False, hack=False):
    """Run the selected precalc parts against a ``'schedule'`` SQLite snapshot.

    :param parts: iterable of part names taken from ``PARTS``. When empty or
        ``None`` (the default), every part runs in ``PARTS`` order.
    :param partial: stored on the ``PrecalcState`` handed to every part.
    :param hack: passed through unchanged to ``SQLiteSnapshot``.
    :raises ValueError: if any requested name is not in ``PARTS``. This is
        raised before the snapshot is opened.

    A new snapshot (``new=True``) is requested only for a full run, that is,
    when neither ``partial`` nor an explicit ``parts`` selection is given.
    Model caches are warmed once via ``precache()`` before any part runs.
    The snapshot is committed only after every requested part has finished;
    an exception from a part propagates and skips the commit.
    """
    # Materialize once. ``parts`` is iterated for validation, iterated again
    # for processing, and measured with len(), which would break for a
    # generator. ``None`` replaces the former shared mutable ``[]`` default.
    parts = tuple(parts or ())

    new = not (partial or parts)

    if not parts:
        parts = PARTS

    # Validate everything up front so a typo fails before any DB work starts.
    for part in parts:
        if part not in PARTS:
            raise ValueError('Unknown part %s' % (part,))

    handlers = {
        'threads': precalc_threads,
        'stops': precalc_stops,
        'search': precalc_search,
        'search_totals': precalc_search_totals,
        'avia_directions': precalc_avia_directions,
    }

    snapshot = SQLiteSnapshot('schedule', new=new, hack=hack)
    connect = snapshot.connect

    precache()

    precalc_state = PrecalcState(partial)

    total = len(parts)
    for i, part in enumerate(parts):
        log.info("Processing %s... (%d/%d)", part, i + 1, total)
        handlers[part](connect, precalc_state)

    snapshot.commit()
