# coding: utf8
from __future__ import unicode_literals, absolute_import, division, print_function

import json
import time
import traceback
from datetime import timedelta
from dateutil import parser

import ibm_db
import mongolock
import pytz
import requests
from django.conf import settings
from mongoengine.queryset import NotUniqueError

from common.apps.suburban_events.models import LVGD01_TR2PROC, LVGD01_TR2PROC_query, LVGD01_TR2PROC_feed
from common.apps.suburban_events.logs import log_rzd_raw_data
from common.apps.suburban_events.scripts.update_companies_crashes import calc_companies_crashes
from common.db.sqs.instance import sqs_client
from common.dynamic_settings.core import DynamicSetting
from common.dynamic_settings.default import conf
from travel.rasp.library.python.common23.date import environment
from common.utils.iterrecipes import chunker
from common.utils.lock import lock
from travel.rasp.library.python.common23.logging import log_run_time
from travel.rasp.library.python.common23.logging.scripts import script_context
from travel.rasp.library.python.sqs.queue import FIFOQueue
from travel.rasp.suburban_tasks.suburban_tasks.cpy_pst import get_script_logger
from travel.rasp.suburban_tasks.suburban_tasks.models import LVGD01_TR2PROC_feed_mysql
from travel.rasp.suburban_tasks.suburban_tasks.rzd_utils import get_connect, get_rowdicts_from_resource, rzd_sql_callproc, rzd_db_manager


# Script-level logger shared by every function in this module.
log = get_script_logger()

# Moscow timezone object.
MSK_TZ = pytz.timezone('Europe/Moscow')
# First argument passed to the '@LVGD01.TR2PROC' stored procedure in fetch_data.
FETCH_TRTYPE = 2  # suburban
# strftime format used by fetch_data to render the range bounds sent to the procedure.
FETCH_DATE_FORMAT = '%Y-%m-%d-%H.%M'


# Register the dynamic settings read by this script (via ``conf.<NAME>``).
conf.register_settings(
    # Value handed to time.sleep() between failed fetch attempts in fetch_data.
    # Description (English): "How long we wait before retrying to get suburban
    # train running data from RZD".
    SUBURBAN_EVENTS_GET_RZD_DATA_RETRY_DELAY=DynamicSetting(
        20, cache_time=1,
        description='Сколько ждем перед повторной попыткой получить данные о хождении электричек от РЖД'),
    # Passed as ``retries`` to fetch_data by get_rzd_data.
    # Description (English): "Maximum number of attempts to get suburban train
    # running data from RZD".
    SUBURBAN_EVENTS_GET_DATA_MAX_RETRIES=DynamicSetting(
        10, cache_time=1,
        description='Максимальное количество попыток получить данные данные о хождении электричек от РЖД'),
)


def fetch_data(dt_from, dt_to, retries=3):
    """Call the RZD stored procedure ``@LVGD01.TR2PROC`` for a time range.

    Every attempt opens a connection with ``get_connect`` inside
    ``rzd_db_manager``, calls the procedure with
    ``(FETCH_TRTYPE, dt_from, dt_to)`` (both bounds rendered with
    ``FETCH_DATE_FORMAT``), reads the row dicts from the first element of the
    procedure result and commits. If an attempt raises, the error is logged
    and, unless it was the last attempt, the function sleeps for
    ``conf.SUBURBAN_EVENTS_GET_RZD_DATA_RETRY_DELAY`` seconds and tries again.

    :param dt_from: lower bound of the range (object with ``strftime``)
    :param dt_to: upper bound of the range (object with ``strftime``)
    :param retries: maximum number of attempts, must be at least 1
    :return: tuple ``(rowdicts, tries_count)`` where ``tries_count`` is the
        1-based number of the attempt that succeeded
    :raises ValueError: if ``retries`` is less than 1
    :raises Exception: whatever the last attempt raised, once all attempts
        are exhausted
    """
    # Without this guard a non-positive ``retries`` would skip the loop and
    # silently return None, which callers cannot unpack into two values.
    if retries < 1:
        raise ValueError(u'retries must be at least 1, got {!r}'.format(retries))

    dt_from_str = dt_from.strftime(FETCH_DATE_FORMAT)
    dt_to_str = dt_to.strftime(FETCH_DATE_FORMAT)

    for tries_count in range(1, retries + 1):
        try:
            with rzd_db_manager():
                connect = get_connect(db_name=settings.RZD_SUBURBAN_EVENTS_DB)
                result = rzd_sql_callproc(connect, '@LVGD01.TR2PROC', (FETCH_TRTYPE, dt_from_str, dt_to_str))
                # The first element of the procedure result is the resource to read rows from.
                resource = result[0]
                rowdicts = get_rowdicts_from_resource(resource)
                ibm_db.commit(connect)
                return rowdicts, tries_count
        except Exception:
            log.exception(u'Unable to fetch data from rzd')
            if tries_count == retries:
                # Last attempt failed: propagate the error to the caller.
                raise
            time.sleep(conf.SUBURBAN_EVENTS_GET_RZD_DATA_RETRY_DELAY)


try:
    # Python 2: ``basestring`` covers both ``str`` and ``unicode``.
    _STRING_TYPES = basestring  # noqa: F821
except NameError:
    # Python 3: ``basestring`` was removed, ``str`` is the only text type.
    _STRING_TYPES = str

# Fields that the ibm_db driver does not convert automatically.
_FLOAT_FIELDS = frozenset(['PRSTOP', 'PRIORITY', 'PRIORITY_RATING'])


def prepare_rows_for_saving(rows, query_info):
    """Make each dict from rows valid to be kwargs for LVGD01_TR2PROC model.

    The dicts are modified in place:

    * ``query`` is set to ``query_info``;
    * the ``ID`` key is renamed to ``ID_TRAIN``;
    * ``PRSTOP``, ``PRIORITY`` and ``PRIORITY_RATING`` are converted to float;
    * every other string value is stripped of surrounding whitespace.

    :param rows: iterable of mutable dicts, each containing an ``ID`` key
    :param query_info: value stored under the ``query`` key of every row
    :return: None
    :raises KeyError: if a row has no ``ID`` key
    """
    for row_dict in rows:
        row_dict['query'] = query_info
        row_dict['ID_TRAIN'] = row_dict.pop('ID')
        # Iterate over a snapshot so reassigning values cannot interfere with
        # the iteration on any Python version.
        for field, value in list(row_dict.items()):
            if field in _FLOAT_FIELDS:
                row_dict[field] = float(value)
            elif isinstance(value, _STRING_TYPES):
                row_dict[field] = value.strip()


def lists_diff(base_list, other_list):
    """ Collect the items of base_list that have no equal item in other_list.

        Both inputs are sorted and then walked in parallel, so the items only
        need to support ordering and equality comparison (no __hash__ required).
        The result is returned in sorted order.
    """
    sorted_base = sorted(base_list)
    remaining_others = iter(sorted(other_list))

    # Marker meaning "other_list has been fully consumed".
    exhausted = object()
    current_other = next(remaining_others, exhausted)

    missing = []
    for candidate in sorted_base:
        # Skip every item of other_list that is smaller than the candidate.
        while current_other is not exhausted and candidate > current_other:
            current_other = next(remaining_others, exhausted)

        if current_other is exhausted or not candidate == current_other:
            missing.append(candidate)

    return missing


def save_events_to_feed(events):
    """Store event dicts in the Mongo feed and mirror them to the MySQL feed.

    The dicts are first inserted as ``LVGD01_TR2PROC_feed`` documents and
    passed to ``calc_companies_crashes``. Then the ``query`` key is removed
    from every dict in place and the dicts are bulk-created as
    ``LVGD01_TR2PROC_feed_mysql`` rows. Failures of the crash calculation and
    of the MySQL save are logged and swallowed; a failure of the Mongo insert
    propagates.

    :param events: sized collection of dicts usable as model kwargs
    """
    events_count = len(events)
    if events_count == 0:
        log.info(u'No events to save, skipping')
        return

    log.info(u'Saving {} events to feed'.format(events_count))
    feed_docs = [LVGD01_TR2PROC_feed(**event) for event in events]
    LVGD01_TR2PROC_feed.objects.insert(feed_docs)

    # Best effort: a failure here must not prevent the MySQL save below.
    try:
        calc_companies_crashes(feed_docs)
    except Exception as error:
        log.exception(u"Can't calc companies crashes: {}".format(repr(error)))

    # Drop the 'query' key before building the MySQL rows.
    for event in events:
        event.pop('query', None)

    try:
        log.info(u'Saving {} rows to mysql feed'.format(events_count))
        mysql_rows = (LVGD01_TR2PROC_feed_mysql(**event) for event in events)
        LVGD01_TR2PROC_feed_mysql.objects.bulk_create(mysql_rows)
    except Exception:
        log.exception(u"Can't save new events to mysql")
    else:
        log.info(u"Saving to mysql done")


def _serialize_row_for_queue(row):
    """Turn a row object into a JSON-serializable dict for the queue."""
    row_dict = row.to_dict()
    row_dict.pop('id', None)
    query = row_dict.pop('query', None)
    row_dict['TIMEOPER_N'] = row_dict['TIMEOPER_N'].isoformat()
    row_dict['TIMEOPER_F'] = row_dict['TIMEOPER_F'].isoformat()

    if query:
        query_dict = query.to_dict()
        for date_key in ('queried_at', 'query_to', 'query_from'):
            query_dict[date_key] = query_dict[date_key].isoformat()
        query_dict['id'] = str(query_dict['id'])
        row_dict['query'] = query_dict

    return row_dict


def send_rows_to_remote_server(rows):
    """Push rows to the SQS queue; used to forward RZD data from production to testing.

    Does nothing unless ``settings.SUBURBAN_EVENTS_SEND_TO_SQS`` is truthy.
    Rows are serialized to dicts (dates as ISO strings, the row ``id``
    dropped, the query id as a string) and pushed as JSON in chunks of
    ``settings.SUBURBAN_EVENTS_QUEUE_MAX_BATCH_SIZE``. Errors while sending
    are logged and swallowed.

    :param rows: iterable of objects with ``to_dict()`` whose dict contains
        ``TIMEOPER_N`` and ``TIMEOPER_F`` values supporting ``isoformat()``
    """
    if not settings.SUBURBAN_EVENTS_SEND_TO_SQS:
        return

    events = [_serialize_row_for_queue(row) for row in rows]

    # Bound before the try block so the except handler can always format its
    # message, even if reading the queue name from settings is what failed.
    queue_name = None
    try:
        queue_name = settings.SUBURBAN_EVENTS_QUEUE_NAME
        queue = FIFOQueue(sqs_client, queue_name)

        for chunk in chunker(events, settings.SUBURBAN_EVENTS_QUEUE_MAX_BATCH_SIZE):
            queue.push(json.dumps(chunk))

    except Exception:
        log.exception(u"Can't send events to queue {}".format(queue_name))
    else:
        log.info(u"Sent {} events to {}".format(len(events), queue_name))


def read_rows_from_sqs():
    """Read RZD events from the SQS queue and save them to the feed.

    Events arrive grouped by query from ``read_events_by_query``; every group
    is passed to ``save_events_to_feed``. Any error is logged with the queue
    name and re-raised.
    """
    queue_name = settings.SUBURBAN_EVENTS_QUEUE_NAME
    events_queue = FIFOQueue(sqs_client, queue_name)

    try:
        batches = read_events_by_query(events_queue)
        for batch, batch_query in batches:
            save_events_to_feed(batch)
            log.info(u"Read {} rzd events selected by query {}".format(len(batch), batch_query))
    except Exception:
        log.error(u"Can't read events from queue {}".format(queue_name))
        raise


def read_events_by_query(queue):
    """Yield ``(events, query)`` pairs, merging consecutive messages of one query.

    Every message is parsed as JSON, normalized with ``fix_events_query`` and
    then deleted from the queue. Events are accumulated while the query stays
    the same; when a different query shows up the accumulated events are
    yielded together with their query. The total number of received messages
    is logged when the generator finishes, fails or is closed.

    :param queue: object providing ``receive_messages()``,
        ``delete_message(token)`` and ``queue_name``
    """
    received = 0
    current_query = None
    pending = []

    try:
        for body, receipt in queue.receive_messages():
            received += 1
            batch, batch_query = fix_events_query(json.loads(body))

            # The message is removed as soon as it has been parsed.
            queue.delete_message(receipt)

            query_changed = current_query is not None and batch_query != current_query
            if not query_changed:
                pending.extend(batch)
            else:
                if pending:
                    yield pending, current_query
                pending = batch

            current_query = batch_query

        # Flush whatever is left after the last message.
        if pending:
            yield pending, current_query
    finally:
        log.info(u"Read total {} messages from queue {}".format(received, queue.queue_name))


def fix_events_query(events):
    """Restore Python objects in events decoded from a queue message.

    The ``query`` dict of the first event (if any) has its date strings
    parsed, gets ``source='uploaded'`` and is saved as a
    ``LVGD01_TR2PROC_query``; if such a query already exists it is loaded by
    id instead. Every event then loses its ``_id`` and ``query`` keys, has
    ``TIMEOPER_F`` / ``TIMEOPER_N`` parsed into datetimes and, when a query
    was found, gets the query object under ``query``. Events are modified in
    place.

    :param events: list of event dicts decoded from JSON
    :return: tuple ``(events, query)``; ``query`` is None when the first
        event carries no query, and ``([], None)`` is returned for empty input
    """
    # An empty message has no first event to take the query from; returning
    # early avoids an IndexError that would abort the whole queue read.
    if not events:
        return [], None

    query = events[0].get('query', None)

    if query:
        for date_key in ('queried_at', 'query_from', 'query_to'):
            query[date_key] = parser.parse(query[date_key])
        query['source'] = 'uploaded'
        try:
            query = LVGD01_TR2PROC_query.objects.create(**query)
        except NotUniqueError:
            # ``query`` is still the dict here because create() raised.
            query = LVGD01_TR2PROC_query.objects.get(id=query['id'])

    for event in events:
        event.pop('_id', None)
        event.pop('query', None)
        event['TIMEOPER_F'] = parser.parse(event['TIMEOPER_F'])
        event['TIMEOPER_N'] = parser.parse(event['TIMEOPER_N'])

        if query:
            event['query'] = query

    return events, query


def is_fetch_possible():
    """Tell whether a new fetch from RZD may be started right now.

    Returns False when ``conf.SUBURBAN_RZD_FETCH_ENABLED`` is falsy, or when
    fewer than ``conf.SUBURBAN_RZD_MIN_FETCH_INTERVAL`` seconds have passed
    since ``queried_at`` of the most recent query without an exception.
    Returns True otherwise, including when there is no such query.
    """
    if not conf.SUBURBAN_RZD_FETCH_ENABLED:
        log.info(u'rzd fetch is disabled')
        return False

    last_ok_query = LVGD01_TR2PROC_query.objects.filter(exception=None).order_by('-queried_at').first()
    if not last_ok_query:
        return True

    current_time = environment.now()
    earliest_allowed = last_ok_query.queried_at + timedelta(seconds=conf.SUBURBAN_RZD_MIN_FETCH_INTERVAL)
    if current_time >= earliest_allowed:
        return True

    log.info(u"Can fetch only once per {} seconds. Last succesfull fetch was at {}".format(
        conf.SUBURBAN_RZD_MIN_FETCH_INTERVAL, last_ok_query.queried_at
    ))
    return False


def get_rzd_data():
    """Fetch a new portion of events from RZD, store it and pass on the new rows.

    Steps performed:

    1. Return immediately when ``is_fetch_possible()`` is false.
    2. Compute the fetch range: ``dt_to`` is "now" shifted back by
       ``conf.SUBURBAN_RZD_SHIFT_FROM_NOW_TO_FETCH`` seconds; ``dt_from``
       continues from the previous query without an exception (minus
       ``conf.SUBURBAN_RZD_FETCH_RANGE_OVERLAP`` seconds) and is never further
       than ``conf.SUBURBAN_RZD_MAX_FETCH_TO_PAST`` seconds before ``dt_to``.
    3. Save a ``LVGD01_TR2PROC_query`` record, fetch the rows and insert them
       as ``LVGD01_TR2PROC`` documents.
    4. Keep only rows that have no equal row in the previous query and hand
       them to ``save_events_to_feed``, ``log_rzd_raw_data`` and
       ``send_rows_to_remote_server``.

    Exceptions raised inside the outer ``try`` (steps 3-4 after the first
    save) are logged, stored as a traceback in ``query.exception`` and not
    re-raised; the query record is saved at the end in either case.
    """
    if not is_fetch_possible():
        return

    msk_now = environment.now()
    # Upper bound of the range, stored without tzinfo.
    dt_to = (msk_now - timedelta(seconds=conf.SUBURBAN_RZD_SHIFT_FROM_NOW_TO_FETCH)).replace(tzinfo=None)
    # Latest (by id) query that has no exception recorded.
    # NOTE(review): is_fetch_possible picks its query by '-queried_at', this one by '-id'.
    prev_query = LVGD01_TR2PROC_query.objects.filter(exception=None).order_by('-id').first()
    if prev_query:
        log.info(u'Previous query is {} with query_to {}'.format(prev_query.id, prev_query.query_to))
        # Start slightly before the end of the previous range so the ranges overlap.
        dt_from = prev_query.query_to - timedelta(seconds=conf.SUBURBAN_RZD_FETCH_RANGE_OVERLAP)
        # Cap the range length at SUBURBAN_RZD_MAX_FETCH_TO_PAST seconds.
        if dt_to - dt_from > timedelta(seconds=conf.SUBURBAN_RZD_MAX_FETCH_TO_PAST):
            dt_from = dt_to - timedelta(seconds=conf.SUBURBAN_RZD_MAX_FETCH_TO_PAST)
    else:
        # No previous query: fetch the maximum allowed range.
        dt_from = dt_to - timedelta(seconds=conf.SUBURBAN_RZD_MAX_FETCH_TO_PAST)

    # The query record is saved before fetching, so it exists even if the fetch fails.
    query = LVGD01_TR2PROC_query(
        queried_at=msk_now.replace(tzinfo=None),
        query_from=dt_from,
        query_to=dt_to)
    query.save()

    try:
        try:
            with log_run_time(u'Fetch data for {} - {}'.format(dt_from, dt_to), logger=log) as t:
                rows, tries_count = fetch_data(dt_from, dt_to, retries=conf.SUBURBAN_EVENTS_GET_DATA_MAX_RETRIES)
        finally:
            # Recorded even when fetch_data raised.
            # Presumably t() returns the elapsed time measured by log_run_time - TODO confirm.
            query.time_taken = t()

        # Rows are modified in place to become valid LVGD01_TR2PROC kwargs.
        prepare_rows_for_saving(rows, query)
        new_rows = [LVGD01_TR2PROC(**d) for d in rows]
        LVGD01_TR2PROC.objects.insert(new_rows)
        query.rows_count = len(rows)
        query.tries_count = tries_count
        log.info(u'Saving {} rows for query {}'.format(query.rows_count, query.id))
        query.save()

        if prev_query:
            prev_rows = LVGD01_TR2PROC.objects.filter(query=prev_query)
        else:
            prev_rows = []

        # Drop rows that also came with the previous query (the ranges overlap).
        # lists_diff sorts and compares the model objects, so it relies on the
        # comparison operators of LVGD01_TR2PROC (defined outside this file).
        actually_new_rows = lists_diff(new_rows, prev_rows)
        log.info(u'Found {} actually new rows (no doubles with previous query)'.format(len(actually_new_rows)))
        query.new_rows_count = len(actually_new_rows)

        rows_dicts = [row.to_dict() for row in actually_new_rows]

        # ObjectId is not needed for any of the remaining operations here
        for rd in rows_dicts:
            rd.pop('id', None)

        # save_events_to_feed also removes the 'query' key from these dicts in place,
        # so log_rzd_raw_data below receives them without it.
        save_events_to_feed(rows_dicts)

        now = environment.now()
        log_rzd_raw_data(now.isoformat(), now, rows_dicts)

        send_rows_to_remote_server(actually_new_rows)
    except Exception:
        log.exception(u'Unable to get data')
        # A stored traceback marks the query as failed; queries with an exception
        # are skipped by the ``filter(exception=None)`` lookups.
        query.exception = traceback.format_exc()

    # Persist time_taken, new_rows_count and a possible exception.
    query.save()


def run():
    """Script entry point: run ``get_rzd_data`` under the fetch lock.

    The work is done inside the ``lock_fetch_events_from_rzd`` lock and the
    ``update_suburban_events`` script context. When the lock is already held
    the run is skipped and only a debug message is logged.
    """
    try:
        with lock('lock_fetch_events_from_rzd', 'suburban_events_process',
                  database_name=settings.SUBURBAN_EVENTS_DATABASE_NAME), \
                script_context('update_suburban_events', report_progress=False):
            get_rzd_data()
    except mongolock.MongoLockLocked as lock_error:
        log.debug('Can not get lock: %s', repr(lock_error))


# Allow running this module directly as a script.
if __name__ == '__main__':
    run()
