# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
import os
from dateutil import parser
from datetime import datetime, timedelta
from collections import defaultdict
from itertools import chain, combinations

from django.db.models import Q, F

from common.apps.archival_data.models import ArchivalSearchData, ArchivalSettlementsData
from common.data_api.baris.all_flights import all_flights_iterate
from common.data_api.baris.instance import baris
from common.data_api.baris.helpers import make_baris_masks_from_protobuf, BarisMasksProcessor, BARIS_TITLE_DASH
from common.db.mongo.bulk_buffer import BulkBuffer
from common.models.geo import Settlement, Station, Station2Settlement
from common.models.schedule import RThreadType
from common.models.transport import TransportType
from common.models_utils.geo import Point
from common.utils.date import FuzzyDateTime
from travel.rasp.library.python.common23.date.environment import now, now_aware
from travel.rasp.library.python.common23.logging import log_run_time
from common.utils.multiproc import run_instance_method_parallel
from common.utils.tz_mask_split import calculate_mask_for_tz_and_station


log = logging.getLogger(__name__)


class SearchDataGenerator(object):
    """Refreshes ``ArchivalSearchData`` for every known (point_from, point_to, transport type).

    Each station-to-station route (from ``ZNodeRoute2`` and, optionally, the
    BARIS p2p summary for planes) is expanded into up to four point pairs:
    station or settlement on the departure side times station or settlement
    on the arrival side.
    """

    def get_points_dicts(self, settlement_from_id, settlement_to_id, t_type_id, station_from_id, station_to_id):
        """Return ``{(point_from_key, point_to_key, t_type_id): {'update_dt': <now>}}`` for one route.

        The station-to-station pair is always included; pairs that involve a
        settlement are included only when that settlement id is truthy.
        All entries share a single value dict stamped with one ``now()`` call.
        """
        settlement_from_key = Point.get_point_key(Settlement, settlement_from_id) if settlement_from_id else None
        settlement_to_key = Point.get_point_key(Settlement, settlement_to_id) if settlement_to_id else None
        from_keys = [Point.get_point_key(Station, station_from_id)]
        to_keys = [Point.get_point_key(Station, station_to_id)]

        if settlement_from_key:
            from_keys.append(settlement_from_key)
        if settlement_to_key:
            to_keys.append(settlement_to_key)

        stamp = {'update_dt': now()}
        return {(key_from, key_to, t_type_id): stamp for key_from in from_keys for key_to in to_keys}

    def generate_baris_data(self):
        """Yield ``get_points_dicts`` results for every station pair of the BARIS p2p summary (planes)."""
        with log_run_time('generate_baris_data', logger=log):
            linked_rows = Station2Settlement.objects.values_list('station_id', 'settlement_id')
            plane_rows = Station.objects.filter(t_type=TransportType.PLANE_ID).values_list('id', 'settlement_id')
            # Rows are consumed in this order, so a plane station's own settlement_id
            # overrides a Station2Settlement link for the same station id.
            settlement_by_station_id = dict(chain(linked_rows, plane_rows))

            for station_from_id, station_to_id in baris.get_p2p_summary():
                yield self.get_points_dicts(
                    settlement_by_station_id.get(station_from_id),
                    settlement_by_station_id.get(station_to_id),
                    TransportType.PLANE_ID,
                    station_from_id,
                    station_to_id,
                )

    def generate_rasp_data(self):
        """Yield ``get_points_dicts`` results for every distinct ``ZNodeRoute2`` route."""
        from route_search.models import ZNodeRoute2

        # Column order matches the positional parameters of get_points_dicts.
        columns = ('settlement_from_id', 'settlement_to_id', 't_type_id', 'station_from_id', 'station_to_id')
        with log_run_time('get znoderoute_objects', logger=log):
            rows = list(ZNodeRoute2.objects.all().values_list(*columns).distinct())

        log.info('znoderoute_objects count: {}'.format(len(rows)))

        for row in rows:
            yield self.get_points_dicts(*row)

    def generate(self, use_baris=False):
        """Collect point pairs (rasp, then optionally BARIS) and upsert their ``update_dt``.

        :param use_baris: also include plane station pairs from the BARIS summary.
        """
        sources = [self.generate_rasp_data()]
        if use_baris:
            sources.append(self.generate_baris_data())

        # Later sources overwrite earlier entries with the same key.
        update_dicts = {}
        for points_dict in chain.from_iterable(sources):
            update_dicts.update(points_dict)

        log.info('update_dicts count: {}'.format(len(update_dicts)))

        with log_run_time('archival_data update', logger=log):
            collection = ArchivalSearchData._get_collection()
            with BulkBuffer(collection, max_buffer_size=20000) as coll:
                for key, update_dict in update_dicts.items():
                    point_from, point_to, t_type_id = key
                    query = {
                        'point_from': point_from,
                        'point_to': point_to,
                        'transport_type': t_type_id
                    }
                    coll.update_one(query, {'$set': update_dict}, upsert=True)


class SettlementsRaspDataGenerator(object):
    """Builds settlement-to-settlement segment lists for every non-plane transport type.

    Settlement pairs come from ``ZNodeRoute2`` and are searched with
    ``route_search.shortcuts.search_routes``, split across
    ``RASP_ARCHIVAL_POOL_SIZE`` workers. ``generate`` returns a
    ``defaultdict(list)`` keyed by ``(settlement_from_key, settlement_to_key, t_type_id)``.
    """

    def __init__(self):
        self.settlements_pairs = None
        self.settlements = None
        self.t_types = None
        # A pool size below 1 would either divide by zero (0) or create no
        # worker ranges at all (negative), so it is clamped to 1.
        self.pool_size = max(1, int(os.environ.get('RASP_ARCHIVAL_POOL_SIZE', 1)))
        self.update_dicts = defaultdict(list)

    def _get_transport_type(self, segment):
        """Serialize the segment's transport type, plus its thread subtype when present."""
        transport = {
            'id': segment.t_type.id,
            'code': segment.t_type.code,
            'title': segment.t_type.L_title(),
        }

        subtype = segment.thread.t_subtype
        if subtype:
            transport['subtype'] = {
                'id': subtype.id,
                'code': subtype.code,
                'title': subtype.L_title(),
                'color': subtype.color.color if subtype.color else None
            }

        return transport

    def _get_station(self, station):
        """Serialize ``station`` as ``{'id', 'title'}``."""
        return {
            'id': station.id,
            'title': station.L_title(),
        }

    def _get_thread(self, thread):
        """Serialize the thread's interval times (ISO strings or None), density and number."""
        return {
            'begin_time': thread.begin_time.isoformat() if thread.begin_time else None,
            'end_time': thread.end_time.isoformat() if thread.end_time else None,
            'density': thread.density,
            'number': thread.number
        }

    def _get_workers_data(self):
        """Split ``settlements_pairs`` into ``pool_size`` contiguous ``(start, end)`` index ranges.

        All ranges have the same width, so trailing ranges may extend past the
        end of the list (slicing clips them) or be empty.
        """
        items_for_worker = len(self.settlements_pairs) // self.pool_size + 1
        return [(i * items_for_worker, (i + 1) * items_for_worker) for i in range(self.pool_size)]

    def _search_settlements_routes(self):
        """Run ``_search_routes_parallel`` over all worker ranges and merge results into ``self.update_dicts``."""
        workers_data = self._get_workers_data()

        total_opers_count = 0
        with log_run_time('search_settlements_routes: run in {} processes'.format(self.pool_size), logger=log):
            for update_dicts in run_instance_method_parallel(self._search_routes_parallel, workers_data, self.pool_size):
                log.info('worker done {} operations'.format(len(update_dicts)))
                self.update_dicts.update(update_dicts)
                total_opers_count += len(update_dicts)

            log.info('total operations done: {}'.format(total_opers_count))

    def _calculate_days_by_tz(self, segment):
        """Return ``{segment.station_from.time_zone: <days text dict>}`` for the segment's thread.

        Interval threads use no shift and today's date in the thread timezone;
        other threads use the departure day shift of ``rtstation_from`` from the
        segment's calculated start date.
        """
        if segment.thread.type_id == RThreadType.INTERVAL_ID:
            shift = 0
            thread_start_date = now_aware().astimezone(segment.thread.pytz).date()
        else:
            # calc_days_shift reads the thread from the rtstation, so attach it first.
            segment.rtstation_from.thread = segment.thread
            shift = segment.rtstation_from.calc_days_shift(
                event='departure',
                start_date=segment.calculated_start_date
            )
            thread_start_date = segment.calculated_start_date
        days_data = segment.thread.L_days_text_dict(
            shift=shift,
            thread_start_date=thread_start_date,
            show_days=True
        )

        return {segment.station_from.time_zone: days_data}

    def _search_routes_parallel(self, ind_from, ind_to):
        """Search routes for ``settlements_pairs[ind_from:ind_to]``.

        Intended to be run by ``run_instance_method_parallel`` in a worker.

        :return: ``defaultdict(list)`` keyed by
            ``(settlement_from.point_key, settlement_to.point_key, t_type_id)``
            with one serialized segment per found route.
        """
        from route_search import shortcuts

        items = self.settlements_pairs[ind_from: ind_to]
        update_dicts = defaultdict(list)

        for settlement_from_id, settlement_to_id in items:
            settlement_from = self.settlements.get(settlement_from_id)
            settlement_to = self.settlements.get(settlement_to_id)
            search_result = shortcuts.search_routes(
                settlement_from,
                settlement_to,
                transport_types=self.t_types
            )

            routes = search_result[0]

            if not routes:
                log.info('routes not found {} {}'.format(settlement_from.point_key, settlement_to.point_key))

            for s in routes:
                update_dicts[(settlement_from.point_key, settlement_to.point_key, s.t_type.id)].append({
                    'title': s.L_title(),
                    'arrival': s.arrival.dt if isinstance(s.arrival, FuzzyDateTime) else s.arrival,
                    'departure': s.departure.dt if isinstance(s.departure, FuzzyDateTime) else s.departure,
                    'station_from': self._get_station(s.station_from),
                    'station_to': self._get_station(s.station_to),
                    'run_days_by_tz': self._calculate_days_by_tz(s),
                    'transport_type': self._get_transport_type(s),
                    'thread': self._get_thread(s.thread)
                })

        return update_dicts

    def generate(self):
        """Search all distinct non-plane settlement pairs and return the collected segments."""
        from route_search.models import ZNodeRoute2
        with log_run_time('get znoderoute_objects', logger=log):
            settlements_pairs = list(
                ZNodeRoute2.objects.filter((
                    Q(settlement_from__isnull=False, settlement_to__isnull=False) &
                    ~Q(settlement_from_id=F('settlement_to_id')) &
                    ~Q(t_type_id=TransportType.PLANE_ID)
                )).values_list('settlement_from_id', 'settlement_to_id')
                .distinct()
            )

        log.info('settlements_pair count: {}'.format(len(settlements_pairs)))

        self.settlements_pairs = settlements_pairs
        # Planes are handled separately (SettlementsBarisDataGenerator).
        self.t_types = [t for t in TransportType.objects.all_cached() if t.id != TransportType.PLANE_ID]
        with log_run_time('get settlements', logger=log):
            settlement_ids = chain.from_iterable(self.settlements_pairs)
            self.settlements = {s.id: s for s in Settlement.objects.filter(id__in=settlement_ids)}

        with log_run_time('search', logger=log):
            self._search_settlements_routes()

        log.info('update_dicts count: {}'.format(len(self.update_dicts)))

        return self.update_dicts


class SettlementsBarisDataGenerator(object):
    """Builds settlement-to-settlement plane segments from BARIS flight schedules.

    ``generate`` returns a ``defaultdict(list)`` keyed by
    ``(settlement_from_point_key, settlement_to_point_key, TransportType.PLANE_ID)``
    whose values are lists of serialized segment dicts.
    """

    def __init__(self):
        # Plane stations by id; schedule points at any other airport are ignored.
        self.station_by_id = {station.id: station for station in Station.objects.filter(t_type_id=TransportType.PLANE_ID)}
        # Settlements of those stations, used to title routes by settlement where possible.
        self.settlement_by_id = {settlement.id: settlement for settlement in Settlement.objects.filter(
            id__in={station.settlement_id for station in self.station_by_id.values() if station.settlement_id})}

        plane = TransportType.objects.get(id=TransportType.PLANE_ID)
        # Every BARIS segment shares this serialized transport type.
        self.transport_type = {
            'id': plane.id,
            'code': plane.code,
            'title': plane.L_title()
        }

    def _get_transport_type(self):
        """Return the serialized plane transport type shared by all segments."""
        return self.transport_type

    def _get_station(self, station):
        """Serialize ``station`` as ``{'id', 'title'}``."""
        return {
            'id': station.id,
            'title': station.L_title(),
        }

    def _get_thread(self, number):
        """Serialize a flight as a thread dict that carries only its number."""
        return {
            'number': number
        }

    def _calculate_days_by_tz(self, station_from, mask):
        """Return ``{station_from.time_zone: {'days_text': <mask days text>}}``."""
        return {station_from.time_zone: {'days_text': unicode(mask.format_days_text())}}

    def _get_schedule_mask(self, schedule):
        """Return the run mask of ``schedule``, or ``None`` if its first airport is not a known plane station.

        The mask is built in the timezone of the first route airport.
        """
        if schedule.Route[0].AirportID not in self.station_by_id:
            return

        from_tz = self.station_by_id[schedule.Route[0].AirportID].pytz
        baris_masks = make_baris_masks_from_protobuf(schedule)
        mask_processor = BarisMasksProcessor(baris_masks, from_tz)

        return mask_processor.run_mask

    def _get_schedule_points(self, schedule):
        """Return route points at known plane stations, in route order.

        Each point is ``{'station', 'arrival_time', 'arrival_day_shift',
        'departure_time', 'departure_day_shift'}``; times are ``datetime.time``
        or ``None`` when the source string is empty.
        """
        points = [
            {
                'station': self.station_by_id.get(point.AirportID),
                'arrival_time': parser.parse(point.ArrivalTime).time() if point.ArrivalTime else None,
                'arrival_day_shift': point.ArrivalDayShift,
                'departure_time': parser.parse(point.DepartureTime).time() if point.DepartureTime else None,
                'departure_day_shift': point.DepartureDayShift,
            }
            for point in schedule.Route if point.AirportID in self.station_by_id
        ]

        return points

    def _get_event_dt(self, event_date, event_time, shift):
        """Combine ``event_date`` and ``event_time`` and move the result ``shift`` days forward."""
        return datetime.combine(event_date, event_time) + timedelta(days=shift)

    def _get_title(self, route):
        """Return ``'<first> <dash> <last>'`` for a list of station ids.

        Each end is titled by its station's settlement when that settlement is
        known, otherwise by the station itself.
        """
        first_station, last_station = self.station_by_id[route[0]], self.station_by_id[route[-1]]

        if first_station.settlement_id in self.settlement_by_id:
            first_point = self.settlement_by_id[first_station.settlement_id]
        else:
            first_point = first_station
        if last_station.settlement_id in self.settlement_by_id:
            last_point = self.settlement_by_id[last_station.settlement_id]
        else:
            last_point = last_station

        return '{} {} {}'.format(first_point.L_title(), BARIS_TITLE_DASH, last_point.L_title())

    def _get_schedule_dicts(self, first_date, points, flight_number, run_mask):
        """Serialize every ordered pair of schedule points as a segment.

        :param first_date: first run date; point day shifts are added to it.
        :param points: output of ``_get_schedule_points`` (route order).
        :param flight_number: flight title, stored as the thread number.
        :param run_mask: schedule run mask from ``_get_schedule_mask``.
        :return: ``defaultdict(list)`` keyed by
            ``(settlement_from.point_key, settlement_to.point_key, PLANE_ID)``.
            Pairs where either station has no settlement are skipped.
        """
        update_dicts = defaultdict(list)
        if len(points) < 2:
            # No pairs to build; also keeps _get_title from indexing an empty route.
            return update_dicts

        # The title depends only on the whole route, so build it once instead of per pair.
        title = self._get_title([point['station'].id for point in points])

        for point_from, point_to in combinations(points, 2):
            station_from, station_to = point_from['station'], point_to['station']
            settlement_from, settlement_to = station_from.settlement, station_to.settlement
            if not settlement_from or not settlement_to:
                continue

            arrival_dt = self._get_event_dt(first_date, point_to['arrival_time'], point_to['arrival_day_shift'])
            departure_dt = self._get_event_dt(first_date, point_from['departure_time'], point_from['departure_day_shift'])
            # NOTE(review): station_from.pytz is passed for both timezone arguments, while run_mask was
            # built in the first route airport's timezone; confirm against calculate_mask_for_tz_and_station.
            from_departure_mask = calculate_mask_for_tz_and_station(
                run_mask, point_from['departure_time'], point_from['departure_day_shift'], station_from.pytz, station_from.pytz
            )

            update_dicts[(settlement_from.point_key, settlement_to.point_key, TransportType.PLANE_ID)].append({
                'title': title,
                'arrival': arrival_dt,
                'departure': departure_dt,
                'station_from': self._get_station(station_from),
                'station_to': self._get_station(station_to),
                'run_days_by_tz': self._calculate_days_by_tz(station_from, from_departure_mask),
                'transport_type': self._get_transport_type(),
                'thread': self._get_thread(flight_number)
            })

        return update_dicts

    def _get_flight_schedule_dicts(self, flight, schedule):
        """Return the segment dicts of one flight schedule, or ``None`` when it yields nothing.

        ``None`` means the schedule has no usable points, no run mask, or a
        mask without dates.
        """
        run_mask = self._get_schedule_mask(schedule)
        points = self._get_schedule_points(schedule)
        if not points or not run_mask:
            return None

        dates = run_mask.dates()
        if not dates:
            return None

        return self._get_schedule_dicts(dates[0], points, flight.Title, run_mask)

    def generate(self):
        """Collect segments for every schedule of every BARIS flight.

        A schedule that fails to process is logged with its traceback and
        skipped, so one malformed schedule does not abort the whole run.
        """
        with log_run_time('SettlementsBarisDataGenerator.generate', logger=log):
            update_dicts = defaultdict(list)
            for flight in all_flights_iterate():
                for schedule in flight.Schedules:
                    try:
                        schedule_update_dicts = self._get_flight_schedule_dicts(flight, schedule)
                    except Exception:
                        # log.exception records the traceback; BaseException.message is
                        # deprecated (and often empty), so it is not used for the text.
                        log.exception('failed to process a schedule of flight %s', flight.Title)
                        continue

                    if not schedule_update_dicts:
                        continue

                    for k, v in schedule_update_dicts.items():
                        update_dicts[k].extend(v)

            log.info('update_dicts count: {}'.format(len(update_dicts)))

            return update_dicts


class SettlementsDataGenerator(object):
    """Stores settlement-to-settlement segment lists in ``ArchivalSettlementsData``."""

    def generate(self, use_baris=False):
        """Merge rasp (and optionally BARIS) segments and write them to Mongo.

        An existing document is overwritten only when its ``update_dt`` is more
        than seven days old; a missing document is inserted.

        :param use_baris: also append plane segments built from BARIS schedules.
        """
        segments_by_key = SettlementsRaspDataGenerator().generate()

        if use_baris:
            for key, segments in SettlementsBarisDataGenerator().generate().items():
                segments_by_key[key].extend(segments)

        update_dt = now()
        week_ago = now() - timedelta(days=7)

        log.info('week: {}'.format(week_ago))

        with log_run_time('archival_data update', logger=log):
            collection = ArchivalSettlementsData._get_collection()
            with BulkBuffer(collection, max_buffer_size=1000) as coll:
                for (point_from, point_to, transport_type), segments in segments_by_key.items():
                    update_data = {'segments': segments, 'update_dt': update_dt}
                    key_filter = {
                        'point_from': point_from,
                        'point_to': point_to,
                        'transport_type': transport_type
                    }
                    stale_filter = dict(key_filter)
                    stale_filter['update_dt'] = {'$lt': week_ago}

                    # Refresh the document only when it is older than a week ...
                    coll.update_one(stale_filter, {'$set': update_data})
                    # ... and create it when it does not exist yet.
                    coll.update_one(key_filter, {'$setOnInsert': update_data}, upsert=True)
