# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import logging
from functools import partial

import pymongo

from django.conf import settings

from common.db.mongo import databases
from common.db.mongo.bulk_buffer import BulkBuffer
from common.models.schedule import RThreadType
from travel.rasp.library.python.common23.date import environment
from travel.rasp.library.python.common23.logging import log_run_time
from common.utils.multiproc import run_instance_method_parallel, get_cpu_count
from travel.rasp.info_center.info_center.suburban_notify.changes.find import get_changes_for_interval
from travel.rasp.info_center.info_center.suburban_notify.changes.models import SubscriptionChanges, serialize_sub_changes
from travel.rasp.info_center.info_center.suburban_notify.db import load_db_data, load_points, TRts
from travel.rasp.info_center.info_center.suburban_notify.search import Searcher
from travel.rasp.info_center.info_center.suburban_notify.utils import get_chunks_indexes, disable_logs


# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Shadow the imported log_run_time with a partial that has this module's
# logger pre-bound as the `logger` keyword argument.
log_run_time = partial(log_run_time, logger=log)


class ChangesFinder(object):
    def __init__(self, script_run_id=None):
        """Create an idle finder.

        :param script_run_id: identifier of the current run; when it is falsy
            the ISO-formatted value of ``environment.now_utc()`` is used.
        """
        if not script_run_id:
            script_run_id = environment.now_utc().isoformat()
        self.script_run_id = script_run_id

        # Everything below is filled in later by run() / init_finder().
        self.subs = self.searcher = self.day = None
        self.segments_by_sub = None

    def run(self, subs, day, searcher=None):
        """Find changes for all subscriptions using parallel workers.

        :param subs: sequence of subscriptions; a falsy value short-circuits
            to an empty list before anything else is initialised.
        :param day: object with a ``date()`` method; its result is kept in
            ``self.day``.
        :param searcher: optional searcher; a new ``Searcher()`` is created
            when it is falsy.
        :return: flat list with the results of every worker chunk.
        """
        self.subs = subs
        if not self.subs:
            return []

        self.searcher = searcher or Searcher()
        self.day = day.date()

        self.init_finder()

        workers = get_cpu_count()
        # One worker less than reported is used as the divisor; a chunk is
        # never allowed to be shorter than one subscription.
        chunk_len = max(1, len(self.subs) // max(1, (workers - 1)))

        msg = 'find changes for {} subs with {} workers, and chunk size {}'.format(len(self.subs), workers, chunk_len)
        with log_run_time(msg):
            chunk_bounds = get_chunks_indexes(self.subs, chunk_len)
            collected = []
            worker_results = run_instance_method_parallel(
                self.find_changes_by_indexes, chunk_bounds, pool_size=workers)
            for first, last, chunk_changes in worker_results:
                log.info(
                    'find_changes_by_indexes returned: [%s:%s] -> %s',
                    first, last, len(chunk_changes)
                )
                collected.extend(chunk_changes)

            return collected

    def find_changes_by_indexes(self, index_from, index_to):
        """Find changes for the slice ``self.subs[index_from:index_to]``.

        The whole body runs inside ``disable_logs()``.

        :return: tuple ``(index_from, index_to, changes)`` where ``changes``
            holds only the truthy results of ``find_changes_for_sub``.
        """
        with disable_logs():
            timer_title = 'find_changes_by_indexes: {}: [{}:{}] '.format(
                index_to - index_from, index_from, index_to)
            with log_run_time(timer_title):
                candidates = (self.find_changes_for_sub(sub) for sub in self.subs[index_from:index_to])
                found = [sub_changes for sub_changes in candidates if sub_changes]

            log.info('returning find_changes_by_indexes %s %s: %s', index_from, index_to, len(found))
            return index_from, index_to, found

    def find_changes_for_sub(self, sub):
        """Build a ``SubscriptionChanges`` object for one subscription.

        :param sub: subscription exposing ``point_from_key``, ``point_to_key``
            and ``get_interval(day)``.
        :return: ``SubscriptionChanges`` with ``subscription`` set to ``sub``,
            or ``None`` when no segments are known for the subscription's
            (from, to) pair.
        """
        direction = (sub.point_from_key, sub.point_to_key)
        segments = self.segments_by_sub.get(direction)
        if not segments:
            return None

        interval_from, interval_to = sub.get_interval(self.day)
        interval_changes = get_changes_for_interval(interval_from, interval_to, segments)
        result = SubscriptionChanges.from_dict(
            serialize_sub_changes(sub, interval_changes, self.day)
        )
        result.subscription = sub
        return result

    def init_finder(self):
        """Collect search results for ``self.subs`` and pre-load related data.

        Every distinct (point_from_key, point_to_key) pair found among the
        subscriptions is passed to ``self.searcher.search`` once; a non-empty
        ``segments`` value of the result is stored in ``self.segments_by_sub``
        under that pair. The rts ids seen in the segments go to
        ``load_db_data`` and the point keys go to ``load_points``. Finally
        ``get_start_date_for_event('departure', self.day)`` is called for each
        departure rts whose thread type is ``RThreadType.BASIC_ID``.

        Expects ``self.subs``, ``self.searcher`` and ``self.day`` to be set.
        """
        log.info('init_finder for %s subs', len(self.subs))

        rts_ids = set()
        point_keys = set()
        rts_from_ids = set()
        self.segments_by_sub = {}
        with log_run_time('parse subscriptions & searches'):
            not_found_search_results = 0
            # Several subscriptions may share one direction. Remember the
            # outcome per direction so the searcher is asked only once for it
            # instead of once per subscription.
            segments_by_key = {}
            for sub in self.subs:
                key = (sub.point_from_key, sub.point_to_key)
                if key not in segments_by_key:
                    point_keys.add(sub.point_from_key)
                    point_keys.add(sub.point_to_key)

                    search_res = self.searcher.search(*key)
                    segments = (search_res or {}).get('segments')
                    segments_by_key[key] = segments
                    if segments:
                        self.segments_by_sub[key] = segments
                        for rts_from, rts_to in segments:
                            rts_ids.add(rts_from)
                            rts_from_ids.add(rts_from)
                            rts_ids.add(rts_to)

                # The counter is kept per subscription, not per direction.
                if not segments_by_key[key]:
                    not_found_search_results += 1
            log.info('Search results not found for %s subs', not_found_search_results)

        with log_run_time('load_db_data'):
            load_db_data(rts_ids)

        with log_run_time('load_points'):
            load_points(point_keys)

        with log_run_time('precache rts from start_date'):
            for rts_from_id in rts_from_ids:
                rts_from = TRts.get(rts_from_id)
                if rts_from.thread.type == RThreadType.BASIC_ID:
                    # Result is discarded: the call is made only to warm a cache.
                    rts_from.get_start_date_for_event('departure', self.day)


class ChangesFinderWithStorage(ChangesFinder):
    def __init__(self, script_run_id=None):
        """Initialise the finder; all state is set up by the parent class."""
        parent = super(ChangesFinderWithStorage, self)
        parent.__init__(script_run_id=script_run_id)

    @property
    def coll(self):
        """Return the ``subscription_changes`` collection.

        The database is looked up in ``databases`` by the name configured in
        ``settings.SUBURBAN_NOTIFICATION_DATABASE_NAME`` on every access.
        """
        notify_db = databases[settings.SUBURBAN_NOTIFICATION_DATABASE_NAME]
        return notify_db.subscription_changes

    def run(self, subs, day, searcher=None):
        """Create the collection indexes, then run the parent search.

        Two ascending indexes are requested before delegating: a unique one
        over the full change key (including ``changes.hash``) and a plain one
        over (point_to_key, point_from_key, uid).

        :return: whatever ``ChangesFinder.run`` returns for the same arguments.
        """
        ascending = pymongo.ASCENDING
        unique_key_fields = ('calc_date', 'uid', 'point_from_key', 'point_to_key', 'changes.hash')
        lookup_fields = ('point_to_key', 'point_from_key', 'uid')

        self.coll.create_index([(field, ascending) for field in unique_key_fields], unique=True)
        self.coll.create_index([(field, ascending) for field in lookup_fields])

        parent = super(ChangesFinderWithStorage, self)
        return parent.run(subs, day, searcher=searcher)

    def find_changes_by_indexes(self, index_from, index_to):
        """Find changes for a slice of subscriptions and persist them.

        Non-empty results of the parent implementation are first passed to
        ``process_sub_changes`` and then to ``save_subs_changes``.

        :return: the ``(index_from, index_to, changes)`` triple produced by the
            parent implementation.
        """
        parent = super(ChangesFinderWithStorage, self)
        first, last, found = parent.find_changes_by_indexes(index_from, index_to)
        if not found:
            return first, last, found

        self.process_sub_changes(found)
        self.save_subs_changes(found)
        return first, last, found

    def save_subs_changes(self, subs_changes):
        """Upsert every item of ``subs_changes`` into the collection.

        Each item is serialised with ``to_dict()``; a missing or falsy
        ``script_run_id`` in the result is replaced by ``self.script_run_id``.
        Writes go through a ``BulkBuffer`` created with ``max_buffer_size=5000``.
        """
        with BulkBuffer(self.coll, max_buffer_size=5000) as buff:
            for sub_changes in subs_changes:
                document = sub_changes.to_dict()
                document['script_run_id'] = document.get('script_run_id') or self.script_run_id

                selector = self._get_mongo_key(sub_changes)
                buff.update_one(selector, {'$set': document}, upsert=True)

    def save_subs_changes_all_sent(self, filtered_changes_by_sub):
        """Mark the given changes as pushed in the stored documents.

        :param filtered_changes_by_sub: mapping of a subscription-changes
            object to an iterable of its changes to mark.

        For every change one update is queued that selects the document by the
        subscription key plus ``changes.hash`` and sets the matched array
        element's ``push_sent`` to ``self.script_run_id``.
        """
        with BulkBuffer(self.coll, max_buffer_size=5000) as buff:
            for sub_changes, filtered_changes in filtered_changes_by_sub.items():
                for change in filtered_changes:
                    # A fresh selector per change: it is extended in place.
                    selector = self._get_mongo_key(sub_changes)
                    selector.update({'changes.hash': change.__hash__()})

                    mark_sent = {'$set': {'changes.$.push_sent': self.script_run_id}}
                    buff.update_one(selector, mark_sent)

    def _get_sub_change_key(self, sub_change):
        return (
            sub_change.calc_date,
            sub_change.uid,
            sub_change.point_from_key,
            sub_change.point_to_key,
        )

    def _get_mongo_key(self, sub_change):
        return {
            'calc_date': sub_change.calc_date,
            'uid': sub_change.uid,
            'point_from_key': sub_change.point_from_key,
            'point_to_key': sub_change.point_to_key,
        }

    def process_sub_changes(self, subs_changes):
        if not subs_changes:
            return

        log.info('{} subscription changes'.format(len(subs_changes)))

        subs_changes_by_keys = {}
        mongo_keys = []
        for sub_changes in subs_changes:
            subs_changes_by_keys[self._get_sub_change_key(sub_changes)] = sub_changes
            mongo_keys.append(self._get_mongo_key(sub_changes))

        with log_run_time('find old_subs'):
            keys_buffers = []
            cur_buffer = []
            for key in mongo_keys:
                if len(cur_buffer) == 0:
                    keys_buffers.append(cur_buffer)
                cur_buffer.append(key)
                if len(cur_buffer) == 1000:
                    cur_buffer = []

            old_subs = {}
            for buffer in keys_buffers:
                for old_sub_dict in self.coll.find({'$or': buffer}):
                    old_sub_dict.pop('_id')
                    old_sub_dict.pop('script_run_id')
                    old_sub_changes = SubscriptionChanges.from_dict(old_sub_dict)
                    old_subs[self._get_sub_change_key(old_sub_changes)] = old_sub_changes

            log.info('{} old subscription changes'.format(len(old_subs)))

        not_sent, already_sent = 0, 0
        for key, sub_changes in subs_changes_by_keys.items():
            old_sub_changes = old_subs.get(key)
            if old_sub_changes:
                old_sub_changes_sent = {ch: ch for ch in old_sub_changes.changes if ch.push_sent}
                for change in sub_changes.changes:
                    old_change = old_sub_changes_sent.get(change)
                    if old_change:
                        change.push_sent = old_change.push_sent
                        already_sent += 1
                    else:
                        not_sent += 1
        log.info('{} changes already sent, {} changes not sent'.format(already_sent, not_sent))

    def load(self, query, limit=0):
        """Load stored subscription changes.

        :param query: Mongo filter; a falsy value is replaced by ``{}``.
        :param limit: value passed to the cursor's ``limit()``.
        :return: list of ``SubscriptionChanges`` built from the found
            documents; ``_id`` is excluded by the projection.
        """
        cursor = self.coll.find(query or {}, {'_id': 0}).limit(limit)
        return [SubscriptionChanges.from_dict(document) for document in cursor]

    def get_subs_changes(self, uid, point_from_key, point_to_key):
        """Return stored changes of one user's subscription for a direction.

        Documents are matched by ``uid``, ``point_from_key`` and
        ``point_to_key``; ``_id`` and ``script_run_id`` are removed from each
        document before it is turned into ``SubscriptionChanges``.
        """
        selector = {
            'uid': uid,
            'point_from_key': point_from_key,
            'point_to_key': point_to_key
        }

        result = []
        for document in self.coll.find(selector):
            for service_field in ('_id', 'script_run_id'):
                document.pop(service_field)
            result.append(SubscriptionChanges.from_dict(document))

        return result
