# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import codecs
from datetime import datetime, timedelta, date
from functools import partial
from itertools import groupby
from multiprocessing import Pool

from django.db import connection

from common.models.schedule import RTStation, RThreadType
from common.models.transport import TransportType
from travel.rasp.library.python.common23.date import environment
from common.utils.iterrecipes import pairwise
from common.utils.date import MSK_TZ, RunMask, timedelta2minutes
from common.utils.progress import PercentageStatus
from common.utils.tz_mask_split import MaskSplitter, StationForMaskSplit, ThreadForMaskSplit, thread_mask_split

from travel.rasp.rasp_scripts.scripts.pathfinder.helpers import get_to_pathfinder_year_days_converter
from travel.rasp.rasp_scripts.scripts.pathfinder.tmpfiles import get_tmp_filepath


# Module logger; gen_thegraph_from_rasp_db also hands it to PercentageStatus for progress reporting.
log = logging.getLogger(__name__)


def _filter_minus_one_duration_hack(rtstations):
    """
    Snap slightly-early arrivals to the previous departure (in place).

    For each pair of consecutive stations with both times known: if the
    arrival at the second station is at most one minute *before* the departure
    from the first one, and the second station either has no departure or
    departs no earlier than the first one, its ``msk_arr_dt`` is set to the
    first station's ``msk_dep_dt`` so the segment gets a zero duration
    instead of a negative one.
    """
    for prev_rts, next_rts in pairwise(rtstations):
        prev_departure = prev_rts.msk_dep_dt
        next_arrival = next_rts.msk_arr_dt
        if next_arrival is None or prev_departure is None:
            continue

        travel_minutes = timedelta2minutes(next_arrival - prev_departure)
        if not (-1 <= travel_minutes < 0):
            continue

        next_departure = next_rts.msk_dep_dt
        if next_departure is None or next_departure >= prev_departure:
            next_rts.msk_arr_dt = prev_departure


def _rtstations_msk_to_msk_tz(thread, rtstations, mask):
    """
    Fill in Moscow-time fields for a thread whose stations are all in Moscow time.

    Sets on every station (in place):
      * ``msk_dep_dt``   -- naive departure datetime, or None when ``tz_departure`` is None;
      * ``msk_arr_dt``   -- naive arrival datetime, or None when ``tz_arrival`` is None;
      * ``msk_dep_mask`` -- ``mask`` shifted by the number of whole days between the
        thread start and the departure, or None when there is no departure.

    A missing arrival is reported as None, exactly like _rtstations_to_msk_tz does.
    Previously it was replaced by the thread start time, which made the
    ``msk_arr_dt is None`` guards downstream ineffective for Moscow-only threads
    and produced bogus (negative) durations for stations without an arrival.
    """
    # Any fixed date works: no time zone conversion happens here, so only the
    # time of day and the whole-day offset from the thread start matter.
    start_date = date(2015, 1, 1)
    start_dt = datetime.combine(start_date, thread.tz_start_time)

    for rts in rtstations:
        if rts.tz_departure is not None:
            rts.msk_dep_dt = start_dt + timedelta(minutes=rts.tz_departure)
            rts.msk_dep_mask = mask.shifted((rts.msk_dep_dt.date() - start_date).days)
        else:
            rts.msk_dep_dt = None
            rts.msk_dep_mask = None

        rts.msk_arr_dt = (start_dt + timedelta(minutes=rts.tz_arrival)
                          if rts.tz_arrival is not None else None)


def _rtstations_to_msk_tz(thread, rtstations, mask):
    """
    Fill in Moscow-time fields for a thread whose stations may use several time zones.

    Station-local times are anchored at the first date of ``mask``, localized
    with each station's ``pytz`` and converted to MSK_TZ. Sets on every station
    (in place):
      * ``msk_dep_dt``   -- aware Moscow departure datetime, or None without a departure;
      * ``msk_dep_mask`` -- ``mask`` shifted by the day difference between the Moscow
        departure date and the anchor date, or None without a departure;
      * ``msk_arr_dt``   -- aware Moscow arrival datetime, or None without an arrival.

    :raises StopIteration: if ``mask`` yields no dates.
    """
    # Builtin next() works on both Python 2 and 3, unlike iterator.next().
    naive_start_date = next(mask.iter_dates())
    naive_start_dt = datetime.combine(naive_start_date, thread.tz_start_time)

    for rts in rtstations:
        if rts.tz_departure is not None:
            naive_departure_dt = naive_start_dt + timedelta(minutes=rts.tz_departure)
            rts.msk_dep_dt = rts.pytz.localize(naive_departure_dt).astimezone(MSK_TZ)
            rts.msk_dep_mask = mask.shifted((rts.msk_dep_dt.date() - naive_start_date).days)
        else:
            rts.msk_dep_dt = None
            rts.msk_dep_mask = None

        if rts.tz_arrival is not None:
            naive_arrival_dt = naive_start_dt + timedelta(minutes=rts.tz_arrival)
            rts.msk_arr_dt = rts.pytz.localize(naive_arrival_dt).astimezone(MSK_TZ)
        else:
            rts.msk_arr_dt = None


def _filter_code_sharing(thread, rtstations):
    """
    Mark the thread as a through train if it has a fully code-shared segment.

    A segment counts when its departure station has ``departure_code_sharing``
    and its arrival station has ``arrival_code_sharing``; the first such
    segment sets ``thread.type_id`` to RThreadType.THROUGH_TRAIN_ID (in place).
    """
    for departure_rts, arrival_rts in pairwise(rtstations):
        if departure_rts.departure_code_sharing and arrival_rts.arrival_code_sharing:
            thread.type_id = RThreadType.THROUGH_TRAIN_ID
            break


def _generate_thegraph_string(thread, rtstations, year_days_func, mask_index=0):
    """
    Build thegraph rows for the segments between consecutive stations of a thread.

    Expects ``msk_dep_dt`` / ``msk_arr_dt`` / ``msk_dep_mask`` to be already set
    on every station. Each yielded row is ``map(unicode, ...)`` over:
    departure station id, arrival station id, '1'/'0' fuzzy flag (departure
    station is fuzzy or a technical stop), thread id (``uid``, or
    ``uid(mask_index)`` for a non-zero mask_index), departure time of day,
    stay minutes at the departure station (0 when it has no arrival time),
    travel minutes, thread number (falling back to uid, then ''), transport
    type (the literal 33 for express threads, else ``t_type_id``), thread type
    id and ``year_days_func(msk_dep_mask)`` of the departure station.

    Segments are skipped when a time is missing, or when both ends are
    code-shared on a thread of at most two stations. If any kept segment has a
    negative stay or travel duration, nothing at all is yielded for the thread.
    """
    collected = []
    broken = False

    for src, dst in pairwise(rtstations):
        if src.departure_code_sharing and dst.arrival_code_sharing and len(rtstations) <= 2:
            continue

        if dst.msk_arr_dt is None or src.msk_dep_dt is None:
            continue

        fuzzy = src.is_fuzzy or src.is_technical_stop

        if src.msk_arr_dt:
            stay_minutes = int(timedelta2minutes(src.msk_dep_dt - src.msk_arr_dt))
        else:
            stay_minutes = 0
        travel_minutes = int(timedelta2minutes(dst.msk_arr_dt - src.msk_dep_dt))

        # A negative duration means inconsistent times; the whole thread is dropped below.
        if stay_minutes < 0 or travel_minutes < 0:
            broken = True

        if mask_index:
            thread_id = '{}({})'.format(thread.uid, mask_index)
        else:
            thread_id = thread.uid

        collected.append((
            src.station_id,
            dst.station_id,
            '1' if fuzzy else '0',
            thread_id,
            src.msk_dep_dt.time(),
            stay_minutes,
            travel_minutes,
            thread.number or thread.uid or '',
            33 if thread.express_type == 'express' else thread.t_type_id,
            thread.type_id,
            year_days_func(src.msk_dep_mask),
        ))

    if broken:
        return

    for row in collected:
        yield map(unicode, row)


def _generate_for_msk_zone(thread, rtstations, mask, year_days_func):
    """
    Yield thegraph rows for a thread whose stations are all in Moscow time.

    Stations are annotated in place with Moscow times for ``mask``, the
    one-minute overlap hack and code-sharing marking are applied, and then the
    rows are produced without a mask index suffix.
    """
    _rtstations_msk_to_msk_tz(thread, rtstations, mask)
    _filter_minus_one_duration_hack(rtstations)
    _filter_code_sharing(thread, rtstations)

    rows = _generate_thegraph_string(thread, rtstations, year_days_func)
    for row in rows:
        yield row


def _generate_for_msk_invariant_mask(thread, rtstations, msk_invariant_mask, year_days_func, index=0):
    """
    Yield thegraph rows for one chunk of a multi-zone thread's run mask.

    Stations are converted in place to Moscow times anchored at the chunk's
    first date, cleaned up the same way as in the Moscow-only path, and turned
    into rows. ``index`` is forwarded as ``mask_index``: non-zero values are
    appended to the thread uid as ``uid(index)`` — presumably so rows of
    different chunks carry distinct ids.
    """
    _rtstations_to_msk_tz(thread, rtstations, msk_invariant_mask)
    _filter_minus_one_duration_hack(rtstations)
    _filter_code_sharing(thread, rtstations)

    rows = _generate_thegraph_string(thread, rtstations, year_days_func, index)
    for row in rows:
        yield row


def _make_rasp_db_thread_for_mask_split(mask, rtstations, init_tz):
    """
    Convert a thread's route stations into a ThreadForMaskSplit.

    :param mask: run mask of the thread, passed through unchanged.
    :param rtstations: route stations; each must expose ``pytz``, ``tz_arrival``
        and ``tz_departure`` (minutes from the thread start, or None).
    :param init_tz: thread start time of day -- despite the name a time, not a
        time zone (the caller in this module passes ``thread.tz_start_time``).
    :return: ThreadForMaskSplit of ``mask`` and one StationForMaskSplit per station,
        holding the time of day and whole-day shift of its arrival and departure
        (time None and shift 0 when that event is absent).
    """
    # Any fixed date works: only the time of day and the day offset from the
    # thread start are extracted below.
    anchor_date = date(2015, 1, 1)
    anchor_dt = datetime.combine(anchor_date, init_tz)

    def to_shift_and_time(offset_minutes):
        # (day shift, time of day) for an offset in minutes; (0, None) when absent.
        if offset_minutes is None:
            return 0, None
        moment = anchor_dt + timedelta(minutes=offset_minutes)
        return (moment.date() - anchor_date).days, moment.time()

    split_stations = []
    for rts in rtstations:
        dep_shift, dep_time = to_shift_and_time(rts.tz_departure)
        arr_shift, arr_time = to_shift_and_time(rts.tz_arrival)
        split_stations.append(StationForMaskSplit(
            station_pytz=rts.pytz,
            arrival_time=arr_time,
            arrival_day_shift=arr_shift,
            departure_time=dep_time,
            departure_day_shift=dep_shift,
        ))

    return ThreadForMaskSplit(mask, split_stations)


def _split_thread_mask_to_given_tz_invariant_chunks(mask, rtstations, init_tz, out_pytz, mask_splitter):
    """
    Split a thread's run mask into chunks on which arrival and departure times,
    expressed in the ``out_pytz`` time zone, are the same for every running day.

    :param init_tz: thread start time of day (a time, not a time zone).
    :return: the result of ``thread_mask_split``; the caller in this module
        iterates it as a sequence of masks.
    """
    return thread_mask_split(
        _make_rasp_db_thread_for_mask_split(mask, rtstations, init_tz),
        mask_splitter,
        out_pytz,
    )


def _get_rows_for_thread(thread, rtstations, today, year_days_func, mask_splitter):
    """
    Yield all thegraph rows for one thread.

    Threads whose stations are all in MSK_TZ are handled in a single pass over
    the whole run mask. Otherwise the run mask is first split into chunks on
    which Moscow times stay constant, and each chunk is processed with its
    position as the mask index.
    """
    run_mask = RunMask(thread.year_days, today=today)

    if {rts.pytz for rts in rtstations} == {MSK_TZ}:
        rows = _generate_for_msk_zone(thread, rtstations, run_mask, year_days_func)
        for row in rows:
            yield row
    else:
        invariant_masks = _split_thread_mask_to_given_tz_invariant_chunks(
            run_mask, rtstations, thread.tz_start_time, MSK_TZ, mask_splitter
        )
        for mask_index, invariant_mask in enumerate(invariant_masks):
            rows = _generate_for_msk_invariant_mask(
                thread, rtstations, invariant_mask, year_days_func, mask_index
            )
            for row in rows:
                yield row


def _gen_rows(rts_chunk, today):
    """
    Yield thegraph rows for the given RTStation ids (runs inside a pool worker).

    :param rts_chunk: RTStation primary keys; the parent always passes whole threads.
    :param today: date used for run masks and for the year-days converter.
    """
    # Drop the DB connection inherited from the parent process; Django reconnects lazily.
    connection.close()
    mask_splitter = MaskSplitter()

    # groupby() needs each thread's stations to be contiguous, and pairwise() downstream
    # builds segments from consecutive stations. Request the same order as the parent
    # query (thread, then id) explicitly instead of relying on the model's default ordering.
    rts_iter = (RTStation.objects.filter(pk__in=rts_chunk)
                .order_by('thread__id', 'id')
                .prefetch_related('thread'))

    thread_group_iter = groupby(rts_iter, lambda rts: rts.thread)

    year_days_func = get_to_pathfinder_year_days_converter(today)

    for thread, thread_rts_iter in thread_group_iter:
        rtstations = list(thread_rts_iter)
        for row in _get_rows_for_thread(thread, rtstations, today, year_days_func, mask_splitter):
            yield row


def _process_thegraph_chunk(task, today):
    """
    Pool worker: write thegraph rows for one chunk of RTStation ids to a part file.

    :param task: ``(out_file, rts_chunk)`` -- path of the part file to (over)write
        as UTF-8, and the RTStation primary keys to process.
    :param today: date forwarded to ``_gen_rows``.
    :return: ``(len(rts_chunk), out_file)`` so the parent can report progress
        and merge the part file.
    :raises: re-raises any exception after logging it with its traceback.
    """
    try:
        out_file, rts_chunk = task

        with codecs.open(out_file, 'w', encoding='utf-8') as f:
            for row in _gen_rows(rts_chunk, today):
                f.write('\t'.join(row))
                f.write('\n')

        return len(rts_chunk), out_file
    except Exception as ex:
        # Lazy %r formatting instead of ex.message: .message is deprecated (and gone
        # on Python 3), and formatting a non-ASCII byte-string message into a unicode
        # literal raised UnicodeDecodeError here, hiding the original error.
        log.exception(u'Error during the generation of thegraph in worker. %r', ex)
        raise


def gen_thegraph_from_rasp_db(thegraph_path, processes=10, chunksize=5000):
    """
    Build the thegraph file from the rasp database.

    Selects route stations of non-hidden routes, excluding plane and helicopter
    threads, threads with an empty run mask and thread types other than basic,
    through-train, change, assignment and interval. Their ids are split into
    chunks of whole threads (each at least ``chunksize`` ids, except possibly the
    last). Worker processes render each chunk to a temporary part file, and the
    part files are appended to ``thegraph_path`` in completion order.

    :param thegraph_path: path of the output file (overwritten).
    :param processes: number of worker processes (10 by default, as before).
    :param chunksize: minimum number of RTStation ids per worker task.
    """
    log.info('Generation of the file thegraph from rasp database')

    with open(thegraph_path, 'w') as result_file:

        rts_filtered = RTStation.objects.filter(thread__route__hidden=False)\
            .exclude(thread__t_type_id__in={TransportType.PLANE_ID, TransportType.HELICOPTER_ID})

        rts_query_set = rts_filtered.exclude(thread__year_days=RunMask.EMPTY_YEAR_DAYS) \
            .filter(thread__type__in=[RThreadType.BASIC_ID, RThreadType.THROUGH_TRAIN_ID, RThreadType.CHANGE_ID,
                                      RThreadType.ASSIGNMENT_ID, RThreadType.INTERVAL_ID]) \
            .select_related('thread', 'thread__route') \
            .order_by('thread__id', 'id')

        # Ordered by thread, so groupby() below sees each thread's stations contiguously
        # and a chunk never splits a thread between two workers.
        rts_thread_ids = rts_query_set.values_list('id', 'thread_id')

        def rts_chunk_iter():
            chunk = []

            for _thread_id, thread_pairs in groupby(rts_thread_ids, lambda p: p[1]):
                chunk.extend([p[0] for p in thread_pairs])

                if len(chunk) >= chunksize:
                    out_file = get_tmp_filepath('thread_graph_part', 'pathfinder')
                    yield out_file, chunk
                    chunk = []

            if chunk:
                out_file = get_tmp_filepath('thread_graph_part', 'pathfinder')
                yield out_file, chunk

        # Close the parent's DB connection before forking, presumably so workers
        # do not inherit an open connection.
        connection.close()
        process_pool = Pool(processes)
        try:
            result_iter = process_pool.imap_unordered(
                partial(_process_thegraph_chunk, today=environment.today()),
                rts_chunk_iter()
            )

            status = PercentageStatus(rts_query_set.order_by().count(), log)

            for count, result_filepath in result_iter:
                status.step(count)
                # NOTE(review): part files are not removed here; confirm that the
                # get_tmp_filepath directory is cleaned up elsewhere.
                with open(result_filepath) as result_f:
                    result_file.write(result_f.read())
        except BaseException:
            # Stop outstanding workers instead of leaving them running after a failure.
            process_pool.terminate()
            raise
        else:
            process_pool.close()
        finally:
            # Reap the worker processes in both cases (the pool used to be leaked).
            process_pool.join()

    log.info('Generation of thegraph from rasp database is completed')
