# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from datetime import datetime
from functools import partial

import pytz

from common.apps.suburban_events.models import SuburbanKey
from common.models.geo import Station, Settlement
from common.models.schedule import RTStation, RThread, RThreadType, RTStationLogicMixin
from common.models.transport import TransportType
from common.models_utils.geo import Point
from common.utils.date import RunMask
from travel.rasp.library.python.common23.logging import log_run_time
from common.utils.mysql_try_hard import mysql_try_hard
from travel.rasp.info_center.info_center.suburban_notify.utils import cached, InitableSlots


log = logging.getLogger(__name__)
# Rebind the imported helper so every log_run_time(...) call in this module logs through `log`.
log_run_time = partial(log_run_time, logger=log)

# NOTE: executed at import time - TransportType.objects.get() queries the DB when this module is imported.
suburban_type = TransportType.objects.get(id=TransportType.SUBURBAN_ID)


class Obj(InitableSlots):
    """Base class for lightweight, slot-based in-memory copies of Django model rows.

    Subclasses declare the Django model to query (``DJANGO_MODEL``), the columns fetched
    from the DB (``DB_FIELDS``, in constructor-argument order) and their own ``cache``
    dict mapping object id -> instance.
    """
    # Sentinel distinguishing "no default passed" from an explicit ``default=None``.
    NOTHING = object()

    DJANGO_MODEL = None
    DB_FIELDS = []
    __slots__ = ()
    # id -> instance. Subclasses in this module each define their own dict, so caches are not shared.
    cache = {}

    def __unicode__(self):
        return '{} {}'.format(self.__class__.__name__, self.id)

    def __str__(self):
        # Python 2 style: __unicode__ builds the text, __str__ returns its UTF-8 bytes.
        return self.__unicode__().encode('utf8')

    def __repr__(self):
        return self.__str__()

    @classmethod
    def clear_cache(cls):
        """Drop every cached instance of this class."""
        cls.cache = {}

    @classmethod
    def get(cls, obj_id, default=NOTHING):
        """Return the cached instance with id *obj_id*.

        :param obj_id: id of the object.
        :param default: value returned when *obj_id* is not cached; if omitted, a miss raises.
        :raises KeyError: if *obj_id* is not cached and no *default* was passed.
        """
        if default is Obj.NOTHING:
            return cls.cache[obj_id]
        return cls.cache.get(obj_id, default)

    @classmethod
    @mysql_try_hard
    def load_objs_from_query(cls, query):
        """Build instances from the rows of *query* and store them in the cache.

        Each row of ``query.values_list(*cls.DB_FIELDS)`` is passed positionally to the
        constructor (presumably InitableSlots assigns the values to the slots in order).

        :return: list of the created instances, in query order.
        """
        objs = [cls(*fields) for fields in query.values_list(*cls.DB_FIELDS)]
        for obj in objs:
            cls.cache[obj.id] = obj
        return objs

    @classmethod
    def load_objs_by_ids(cls, ids, model=None):
        """Ensure the objects with the given ids are cached, querying only the uncached ones.

        :param ids: iterable of object ids (any iterable, including a one-shot generator).
        :param model: Django model to query; falls back to ``cls.DJANGO_MODEL``.
        :return: ``(found_objs, not_found_ids)``, both following the order of *ids*.
        """
        ids = list(ids)  # iterated several times below
        uncached_ids = list(set(ids) - set(cls.ids()))
        # An explicitly passed model takes precedence over the class default.
        model = model or cls.DJANGO_MODEL

        with log_run_time('get {} ({}) of {}'.format(len(uncached_ids), len(ids), model)):
            objs = cls.load_objs_from_query(model.objects.filter(id__in=uncached_ids))
            log.debug('got %s of %s', len(objs), model)

        found_objs, not_found_ids = [], []
        for obj_id in ids:
            obj = cls.get(obj_id, None)
            if obj is not None:
                found_objs.append(obj)
            else:
                not_found_ids.append(obj_id)

        return found_objs, not_found_ids

    @classmethod
    def ids(cls):
        """Return the ids of all cached instances."""
        return cls.cache.keys()

    @classmethod
    def all(cls):
        """Return all cached instances."""
        return cls.cache.values()

    @property
    def pytz(self):
        """pytz timezone object for ``self.time_zone``."""
        return self.get_pytz()

    @cached
    def get_pytz(self):
        """Return ``pytz.timezone(self.time_zone)`` (wrapped in the project's @cached decorator)."""
        return pytz.timezone(self.time_zone)


class TRts(Obj, RTStationLogicMixin):
    """Slot-based in-memory copy of an ``RTStation`` row; instances form a thread's ``path``.

    Timing logic comes from ``RTStationLogicMixin``. ``station`` and ``thread`` are loaded
    as raw ids and later replaced with TStation / TThread objects by fetch_related()
    and get_thread_rtses().
    """
    DJANGO_MODEL = RTStation
    DB_FIELDS = ['id', 'tz_arrival', 'tz_departure', 'station', 'thread', 'time_zone']
    MANUAL_FIELDS = ['arrival', 'departure']
    __slots__ = DB_FIELDS + MANUAL_FIELDS
    cache = {}

    def __init__(self, *args):
        super(TRts, self).__init__(*args)

    def __unicode__(self):
        # ``station`` holds a raw station id until fetch_related() swaps in a TStation,
        # so don't assume it has an ``.id`` attribute.
        station_id = getattr(self.station, 'id', self.station)
        return '{} {} {} {}'.format(self.id, station_id, self.tz_arrival, self.tz_departure)

    @cached
    def get_start_date_for_event(self, event, event_date, use_full_mask=True):
        """Return the thread start date for *event* happening on *event_date*, or None.

        The start date (computed by the mixin's ``calc_thread_start_date``) is returned only
        if the thread runs on it according to ``thread.mask_full`` (when *use_full_mask*) or
        ``thread.mask`` otherwise.
        NOTE(review): ``mask_full`` is only assigned to basic threads in load_threads().
        """
        start_date = self.calc_thread_start_date(event, event_date)
        mask = self.thread.mask_full if use_full_mask else self.thread.mask
        if mask[start_date]:
            return start_date
        return None

    # The two overrides below only exist to wrap the mixin's implementation with @cached.
    @cached
    def get_departure_dt(self, naive_start_dt, out_tz=None):
        return super(TRts, self).get_departure_dt(naive_start_dt, out_tz)

    @cached
    def get_arrival_dt(self, naive_start_dt, out_tz=None):
        return super(TRts, self).get_arrival_dt(naive_start_dt, out_tz)

    def get_event_dt_by_date(self, event, start_date, out_tz=None):
        """Return the *event* datetime for the thread run that starts on *start_date*."""
        naive_start_dt = datetime.combine(start_date, self.thread.tz_start_time)
        return self.get_event_dt(event, naive_start_dt, out_tz)


class TThread(Obj):
    """Slot-based in-memory copy of an ``RThread`` row plus data derived in memory.

    The manual (non-DB) fields are filled by the loaders in this module: ``mask`` and
    ``mask_full`` (run masks), ``related_threads``, ``path`` (the thread's TRts objects)
    and ``suburban_key``.
    """
    DJANGO_MODEL = RThread
    DB_FIELDS = [
        'id', 'number', 'uid', 'type', 'tz_start_time', 'time_zone', 'year_days', 'basic_thread_id', 'title'
    ]
    # 'first_rts' is intentionally NOT a slot: it is provided by the first_rts() method below,
    # and a __slots__ entry with the same name as a class attribute is rejected by CPython
    # ("'first_rts' in __slots__ conflicts with class variable").
    MANUAL_FIELDS = ['mask', 'mask_full', 'related_threads', 'path', 'suburban_key']
    __slots__ = DB_FIELDS + MANUAL_FIELDS
    cache = {}

    def __init__(self, *args):
        super(TThread, self).__init__(*args)
        self.related_threads = []
        self.path = []

    @classmethod
    def basic_threads(cls):
        """Return all cached threads of type ``RThreadType.BASIC_ID``."""
        return [t for t in cls.all() if t.type == RThreadType.BASIC_ID]

    @cached
    def get_start_time_minutes(self):
        """Thread start time in minutes since the start of the day."""
        return self.tz_start_time.hour * 60 + self.tz_start_time.minute

    def first_rts(self):
        """Return the first TRts of the thread's path."""
        return self.path[0]

    def last_rts(self):
        """Return the last TRts of the thread's path."""
        return self.path[-1]


class TStation(Obj):
    """Slot-based in-memory copy of a ``Station`` row.

    ``settlement`` holds the value fetched by ``values_list`` (the settlement id).
    """
    DJANGO_MODEL = Station
    DB_FIELDS = ['id', 'settlement', 'title_ru', 'time_zone']
    __slots__ = DB_FIELDS
    cache = {}

    def L_title(self):
        """Return the station's Russian title (``title_ru``)."""
        title = self.title_ru
        return title

    def get_key(self):
        """Return the station's point key: ``'s'`` followed by its id."""
        return 's{}'.format(self.id)


class TSettlement(Obj):
    """Slot-based in-memory copy of a ``Settlement`` row."""
    DJANGO_MODEL = Settlement
    DB_FIELDS = ['id', 'time_zone', 'title_ru']
    __slots__ = DB_FIELDS
    cache = {}

    def L_title(self):
        """Return the settlement's Russian title (``title_ru``)."""
        title = self.title_ru
        return title

    def get_key(self):
        """Return the settlement's point key: ``'c'`` followed by its id."""
        return 'c{}'.format(self.id)


def fetch_related(objs, field, cls):
    """Replace the id stored in ``obj.<field>`` of each object with the cached *cls* instance.

    Related objects that are not cached yet are loaded first via ``cls.load_objs_by_ids``.
    Ids that do not exist in the DB are resolved to None instead of raising KeyError.

    :param objs: objects whose *field* currently holds a related-object id.
    :param field: name of the attribute to resolve.
    :param cls: Obj subclass the ids refer to.
    :return: ``(related_objs, not_found_ids)`` as returned by ``cls.load_objs_by_ids``.
    """
    objs = list(objs)  # iterated twice below
    with log_run_time('fetch_related {}'.format(field)):
        rel_obj_ids = set(getattr(obj, field) for obj in objs)
        rel_objs, not_found_objs = cls.load_objs_by_ids(rel_obj_ids)

        for obj in objs:
            # Missing ids are already reported via not_found_objs; don't crash on them here.
            setattr(obj, field, cls.get(getattr(obj, field), None))

    return rel_objs, not_found_objs


def get_thread_rtses(threads):
    """Load the RTStations of *threads* and store them, in id order, as each thread's ``path``.

    Every loaded TRts gets its ``thread`` field replaced with the cached TThread.
    Paths are rebuilt from scratch, so calling this again for already loaded threads
    does not duplicate their rtses.
    """
    threads = list(threads)  # iterated twice below
    rts_query = RTStation.objects.filter(thread_id__in=[t.id for t in threads]).order_by('id')
    rtses = TRts.load_objs_from_query(rts_query)

    # The query returns every rts of these threads, so previous paths are fully replaced.
    for thread in threads:
        thread.path = []

    for rts in rtses:
        rts.thread = TThread.get(rts.thread)  # thread_id -> thread
        # order_by('id') above guarantees the path order instead of relying on default model ordering.
        rts.thread.path.append(rts)


def find_related_threads():
    """Append each cached thread with a ``basic_thread_id`` to its basic thread's ``related_threads``.

    Threads whose basic thread is not in the TThread cache are skipped.
    """
    for thread in TThread.all():
        basic_thread_id = thread.basic_thread_id
        if basic_thread_id is None:
            continue

        basic_thread = TThread.get(basic_thread_id, None)
        if basic_thread is None:
            # It is expected that not every basic thread has been loaded by us.
            continue

        basic_thread.related_threads.append(thread)


def load_threads(rel_thread_ids):
    """Load threads and the data derived from them into the in-memory caches.

    Steps: load the requested threads, load all non-basic suburban threads, link related
    threads to their basic threads, build run masks, attach rtses to thread paths and
    assign suburban keys.

    :param rel_thread_ids: ids of the threads to load.
    :return: ``(all cached TThread objects, requested ids not found in the DB)``.
    """
    _, not_found_threads = TThread.load_objs_by_ids(rel_thread_ids)

    # Not every related thread reaches us through the search (e.g. when the stations of that thread
    # are cancelled). Here we load into memory all related threads we have not seen yet.
    # There are few of them, and loading all of them is simpler than finding the ones that belong
    # to already loaded basic threads.
    with log_run_time('load all related threads'):
        all_rel_threads = RThread.objects.filter(t_type_id=TransportType.SUBURBAN_ID).exclude(type=RThreadType.BASIC_ID)
        # Separate name so the ``rel_thread_ids`` argument is not shadowed.
        all_rel_thread_ids = list(all_rel_threads.values_list('id', flat=True))
        TThread.load_objs_by_ids(all_rel_thread_ids)

    with log_run_time('find_related_threads'):
        find_related_threads()

    with log_run_time('convert masks'):
        for thread in TThread.all():
            thread.mask = RunMask(thread.year_days)

    with log_run_time('get full masks'):
        # A basic thread's full mask is its own mask OR-ed with the masks of all its related threads.
        for thread in TThread.basic_threads():
            thread.mask_full = RunMask(thread.mask)
            for rel_thread in thread.related_threads:
                thread.mask_full |= rel_thread.mask

    with log_run_time('get_all_rtses'):
        get_thread_rtses(TThread.all())

    with log_run_time('load thread suburban keys'):
        sk_query = SuburbanKey.objects.filter(thread_id__in=TThread.ids())
        threads_with_keys = set()
        for thread_id, key in sk_query.values_list('thread_id', 'key'):
            TThread.get(thread_id).suburban_key = key
            threads_with_keys.add(thread_id)

        not_found = set(TThread.ids()) - threads_with_keys
        if not_found:
            log.error('Suburban keys not found for threads: %s', not_found)

    return TThread.all(), not_found_threads


@mysql_try_hard
def load_db_data(rts_ids, settlement_ids=None, station_ids=None):
    """Fill the in-memory caches with the data needed for the given rtses and points.

    :param rts_ids: ids of RTStation rows; their threads (and related data) are loaded.
    :param settlement_ids: ids of settlements to load, optional.
    :param station_ids: ids of stations to load, optional; the stations of all loaded
        rtses are loaded as well.
    """
    settlement_ids = settlement_ids or []
    station_ids = station_ids or []

    _, not_found_settlements = TSettlement.load_objs_by_ids(settlement_ids)
    _, not_found_requested_stations = TStation.load_objs_by_ids(station_ids)

    thread_ids = set(RTStation.objects.filter(id__in=rts_ids).values_list('thread_id', flat=True))
    _, not_found_threads = load_threads(thread_ids)

    _, not_found_rts_stations = fetch_related(TRts.all(), 'station', TStation)
    # Report both explicitly requested and rts-referenced stations that were missing.
    not_found_stations = set(not_found_requested_stations) | set(not_found_rts_stations)

    log.info('Threads: %s, not found %s', len(TThread.all()), len(not_found_threads))
    log.info('Rtses: %s', len(TRts.all()))
    log.info('Stations: %s, not found %s', len(TStation.all()), len(not_found_stations))
    log.info('Settlements: %s, not found %s', len(TSettlement.all()), len(not_found_settlements))


def is_settlement_key(point_id):
    """Return True if *point_id* is a settlement point key, i.e. starts with ``'c'``."""
    settlement_prefix = 'c'
    return point_id.startswith(settlement_prefix)


def parse_point_key(point_key):
    """Parse a point key into ``(cache class, integer id)``.

    :raises ValueError: if the id is not an integer or the key points to neither a
        station nor a settlement.
    """
    model, raw_id = Point.parse_key(point_key)
    obj_id = int(raw_id)

    for django_model, point_cls in ((Station, TStation), (Settlement, TSettlement)):
        if model is django_model:
            return point_cls, obj_id

    raise ValueError(point_key)


def get_point_by_key(point_key):
    """Return the cached TStation/TSettlement for *point_key* (raises KeyError if not cached)."""
    point_cls, point_id = parse_point_key(point_key)
    return point_cls.get(point_id)


def load_points(point_keys):
    """Parse *point_keys* and load the referenced settlements and stations into their caches."""
    station_ids = set()
    settlement_ids = set()

    for point_key in point_keys:
        point_cls, point_id = parse_point_key(point_key)
        if point_cls is TStation:
            station_ids.add(point_id)
        elif point_cls is TSettlement:
            settlement_ids.add(point_id)

    # Settlements are loaded first, then stations.
    TSettlement.load_objs_by_ids(settlement_ids)
    TStation.load_objs_by_ids(station_ids)


def clear_caches():
    """Empty the in-memory caches of TRts, TThread, TStation and TSettlement (in that order)."""
    cached_classes = (TRts, TThread, TStation, TSettlement)
    for obj_cls in cached_classes:
        obj_cls.clear_cache()
