# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from functools import partial
from common.utils.multiproc import get_cpu_count

import pymongo
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist

from common.db.mongo import databases
from common.db.mongo.bulk_buffer import BulkBuffer
from common.models.geo import Station
from common.models.schedule import RTStation, RThread
from common.models.transport import TransportType
from common.models_utils.geo import Point
from travel.rasp.library.python.common23.logging import log_run_time
from common.utils.multiproc import run_instance_method_parallel
from common.utils.mysql_try_hard import mysql_try_hard
from geosearch.views.pointtopoint import process_points_lists, PointList
from route_search.base import get_threads_from_znoderoute, PlainSegmentSearch
from travel.rasp.info_center.info_center.suburban_notify.utils import disable_logs, get_chunks_indexes

log = logging.getLogger(__name__)
# Pre-bind the module logger so every `log_run_time(...)` call in this module reports through `log`.
log_run_time = partial(log_run_time, logger=log)

# NOTE(review): evaluated at import time -- importing this module performs a DB query and raises
# if the TransportType row with id TransportType.SUBURBAN_ID does not exist.
suburban_type = TransportType.objects.get(id=TransportType.SUBURBAN_ID)


def get_mczk_stations():
    """
    Return the set of station ids that appear on suburban threads whose uid contains 'MCZK'.
    """
    suburban_threads = RThread.objects.filter(t_type=suburban_type).only('id', 'uid')
    mczk_thread_ids = [
        thread_id
        for thread_id, thread_uid in suburban_threads.values_list('id', 'uid')
        if 'MCZK' in thread_uid
    ]

    rtstations = RTStation.objects.filter(thread_id__in=mczk_thread_ids).only('station_id')
    return set(rtstations.values_list('station_id', flat=True))


class Searcher(object):
    """
    Find suburban segments on all days for subscription point pairs.

    Results are stored in ``self.searches_cache`` keyed by
    ``(point_from_key, point_to_key)`` of the original (subscription) points.
    Each value is a dict built by ``gen_search_results``: the old/new point keys
    plus a ``segments`` list of ``[rtstation_from_id, rtstation_to_id]`` pairs.
    """
    def __init__(self):
        self.searches_cache = {}
        # Once True, `search` answers from searches_cache instead of computing on demand.
        self.precached = False

        # List of (point_from, point_to) being processed; set by precalc_searches and
        # sliced by generate_segments_by_indexes in the workers.
        self.point_pairs = None
        self.mczk_stations = None

        with log_run_time('get_mczk_stations'):
            self.mczk_stations = get_mczk_stations()
        log.info('MCZK stations: %s', len(self.mczk_stations))

    def get_point_by_key(self, point_key):
        """Return the point for ``point_key``, or None if it does not exist."""
        try:
            return Point.get_by_key(point_key)
        except ObjectDoesNotExist:
            return None

    def get_point_pairs(self, subs):
        """
        Resolve subscriptions into a set of unique (point_from, point_to) pairs.

        Subscriptions with an unresolvable point are skipped. Every distinct
        unresolvable key is counted in the log, including when both ends are missing.

        :param subs: iterable of objects with ``point_from_key`` / ``point_to_key``.
        :return: set of (point_from, point_to) tuples.
        """
        points_by_key = {}

        def resolve(point_key):
            # Many subscriptions share points, so each key is looked up only once.
            if point_key not in points_by_key:
                points_by_key[point_key] = self.get_point_by_key(point_key)
            return points_by_key[point_key]

        point_pairs, skipped_points = set(), set()
        for sub in subs:
            point_from = resolve(sub.point_from_key)
            point_to = resolve(sub.point_to_key)

            if not point_from:
                skipped_points.add(sub.point_from_key)
            if not point_to:
                skipped_points.add(sub.point_to_key)

            if point_from and point_to:
                point_pairs.add((point_from, point_to))

        log.info('get_point_pairs: not found points: %s', len(skipped_points))

        return point_pairs

    def precalc_searches(self, subs, pool_size=None):
        """
        Compute searches for all subscriptions in parallel and fill ``searches_cache``.

        :param subs: iterable of objects with ``point_from_key`` / ``point_to_key``.
        :param pool_size: number of parallel workers; defaults to ``get_cpu_count()``.
        """
        pool_size = pool_size or get_cpu_count()

        with log_run_time('get_point_pairs'):
            self.point_pairs = list(self.get_point_pairs(subs))
        log.info('point_pairs: %s', len(self.point_pairs))

        if not self.point_pairs:
            self.precached = True
            return

        # Aim for about (pool_size - 1) full chunks. This presumes get_chunks_indexes cuts
        # consecutive chunk_len-sized ranges, with the remainder as one extra chunk -- confirm.
        chunk_len = max(1, len(self.point_pairs) // max(1, pool_size - 1))

        with log_run_time('precalc_searches in parallel on {} workers'.format(pool_size)):
            args_list = get_chunks_indexes(self.point_pairs, chunk_len)
            results = run_instance_method_parallel(self.generate_segments_by_indexes, args_list, pool_size=pool_size)
            for idx_from, idx_to, search_results in results:
                log.info('worker done: %s %s', idx_from, idx_to)

                for search_res in search_results:
                    key = (search_res['point_from_old'], search_res['point_to_old'])
                    self.searches_cache[key] = search_res

        self.precached = True

    def generate_segments_by_indexes(self, idx_from, idx_to):
        """
        Run searches for ``self.point_pairs[idx_from:idx_to]`` (worker entry point).

        :return: ``(idx_from, idx_to, search_results)`` so the caller can match results to chunks.
        """
        with disable_logs():
            log.info('Starting searches [%s:%s]', idx_from, idx_to)
            point_pairs = self.point_pairs[idx_from:idx_to]

            return idx_from, idx_to, list(self.generate_segments(point_pairs))

    def generate_segments(self, point_pairs):
        """
        Run the search pipeline for ``point_pairs`` and return a list of result dicts.

        Pairs dropped by ``modify_point_pairs`` produce no result.
        """
        with log_run_time('modify_point_pairs', log_level=logging.DEBUG):
            new_pp_by_pp = self.modify_point_pairs(point_pairs)
        with log_run_time('get_znr_segments', log_level=logging.DEBUG):
            # Several original pairs can map to the same new pair; query each new pair only once.
            znr_segments = get_znr_segments(set(new_pp_by_pp.values()))
        with log_run_time('filter_znr_segments', log_level=logging.DEBUG):
            znr_segments = filter_znr_segments(znr_segments)

        with log_run_time('save_segments_for_point_pairs', log_level=logging.DEBUG):
            return list(self.gen_search_results(new_pp_by_pp, znr_segments))

    def gen_search_results(self, new_pp_by_pp, segments_by_new_pp):
        """
        Yield one serializable result dict per original point pair.

        :param new_pp_by_pp: {(point_from_old, point_to_old): (point_from_new, point_to_new)}
        :param segments_by_new_pp: {(point_from_new, point_to_new): iterable of segments}
        """
        for (point_from_old, point_to_old), (point_from_new, point_to_new) in new_pp_by_pp.items():
            segments = segments_by_new_pp[(point_from_new, point_to_new)]

            yield {
                'point_from_old': point_from_old.point_key,
                'point_to_old': point_to_old.point_key,
                'point_from_new': point_from_new.point_key,
                'point_to_new': point_to_new.point_key,
                'segments': [
                    [segment.rtstation_from.id, segment.rtstation_to.id]
                    for segment in segments
                ]
            }

    def search(self, point_from_key, point_to_key):
        """
        Return the search result dict for a pair of point keys, or None.

        When precached, this is a cache lookup and a miss returns None. Otherwise the
        search is computed on demand. If ``modify_point_pairs`` drops the pair, None is
        returned instead of failing on an empty result list.
        """
        if self.precached:
            return self.searches_cache.get((point_from_key, point_to_key)) or None

        point_from = Point.get_by_key(point_from_key)
        point_to = Point.get_by_key(point_to_key)

        results = self.generate_segments([(point_from, point_to)])
        return results[0] if results else None

    def is_mczk(self, point):
        """Return True if ``point`` is a Station whose id is in ``self.mczk_stations``."""
        return isinstance(point, Station) and point.id in self.mczk_stations

    def modify_point_pairs(self, point_pairs):
        """
        Map each (point_from, point_to) to the pair of points used for the actual search.

        The new points come from ``process_points_lists(..., suburban=True)``. Presumably
        that picks the concrete stations to search between -- confirm. A pair is left out
        if the call raises, or if either new point is an MCZK station.

        :return: dict {(point_from, point_to): (point_from_new, point_to_new)}
        """
        res = {}
        mczk_skips = 0
        for point_from, point_to in point_pairs:
            point_list_from = PointList(point_from, [point_from], None, exact_variant=True)
            point_list_to = PointList(point_to, [point_to], None, exact_variant=True)

            try:
                point_from_new, point_to_new = process_points_lists(point_list_from, point_list_to, suburban=True)
            except Exception as ex:
                # Best-effort: one bad pair must not abort the whole batch.
                log.warning('process_points_lists error %s %s %s', repr(ex), point_from, point_to)
                continue

            point_from_new, point_to_new = point_from_new.point, point_to_new.point
            if self.is_mczk(point_from_new) or self.is_mczk(point_to_new):
                mczk_skips += 1
                continue

            res[(point_from, point_to)] = (point_from_new, point_to_new)

        log.debug('mczk skipped: %s', mczk_skips)

        return res


@mysql_try_hard
def get_znr_segments(points_pairs):
    """
    Map each (point_from, point_to) to ``get_threads_from_znoderoute(point_from, point_to, suburban_type)``.

    Wrapped in ``mysql_try_hard`` (presumably retries on MySQL errors -- confirm).
    """
    return {
        (point_from, point_to): get_threads_from_znoderoute(point_from, point_to, suburban_type)
        for point_from, point_to in points_pairs
    }


def filter_znr_segments(threads_for_pp):
    """
    Turn prepared threads into all-days segments for each point pair.

    :param threads_for_pp: {(point_from, point_to): prepared_threads}
    :return: {(point_from, point_to): result of PlainSegmentSearch(...).all_days_search()}
    """
    return {
        pair: PlainSegmentSearch(pair[0], pair[1], suburban_type, prepared_threads=threads).all_days_search()
        for pair, threads in threads_for_pp.items()
    }


class SearcherWithStorage(Searcher):
    """
    Searcher that also persists search results into a MongoDB collection.

    ``precalc_searches`` rebuilds the collection from scratch.
    ``load_searches`` later fills the in-memory cache from it without recomputing.
    """
    # Above this many distinct point pairs, load_searches reads the whole collection
    # instead of sending one huge $or query.
    max_or_query_pairs = 3000

    def __init__(self, coll=None):
        super(SearcherWithStorage, self).__init__()
        # Compare with None explicitly: PyMongo 4 Collection objects refuse bool().
        if coll is None:
            coll = databases[settings.SUBURBAN_NOTIFICATION_DATABASE_NAME].searches_cache
        self.coll = coll

    def precalc_searches(self, subs, pool_size=None):
        """Drop stored searches, recompute them (workers insert into ``self.coll``), then index."""
        self.coll.drop()

        super(SearcherWithStorage, self).precalc_searches(subs, pool_size=pool_size)

        # Index is built after the bulk inserts; it serves load_searches' lookups by old keys.
        with log_run_time('create_index for search'):
            self.coll.create_index([("point_from_old", pymongo.ASCENDING), ("point_to_old", pymongo.ASCENDING)])

    def generate_segments_by_indexes(self, idx_from, idx_to):
        """Run the parent's worker step and also insert each result into ``self.coll``."""
        # BulkBuffer presumably batches inserts (up to 2000 at a time) and flushes on exit -- confirm.
        with BulkBuffer(self.coll, max_buffer_size=2000, logger=log) as coll:
            idx_from, idx_to, segments = super(SearcherWithStorage, self).generate_segments_by_indexes(idx_from, idx_to)
            for res in segments:
                coll.insert_one(res)

            return idx_from, idx_to, segments

    def load_searches(self, subs, limit=0):
        """
        Fill ``searches_cache`` from the collection for the given subscriptions.

        With more than ``max_or_query_pairs`` distinct pairs, the whole collection is read.

        :param subs: iterable of objects with ``point_from_key`` / ``point_to_key``.
        :param limit: max number of documents to read; 0 means no limit (pymongo semantics).
        """
        pairs = []
        seen = set()
        for sub in subs:
            pair = (sub.point_from_key, sub.point_to_key)
            if pair not in seen:
                seen.add(pair)
                pairs.append(pair)

        if not pairs:
            # MongoDB rejects an empty $or, and there is nothing to load anyway.
            self.precached = True
            return

        if len(pairs) > self.max_or_query_pairs:
            query = {}
        else:
            query = {'$or': [
                {
                    'point_from_old': point_from_key,
                    'point_to_old': point_to_key,
                }
                for point_from_key, point_to_key in pairs
            ]}

        with log_run_time('get searches'):
            for search in self.coll.find(query).limit(limit):
                self.searches_cache[(search['point_from_old'], search['point_to_old'])] = search

        self.precached = True
