#!/usr/bin/env python
# coding: utf-8

from __future__ import print_function

import os
import re
import logging

from pprint import pformat

from uuid import UUID
from pymongo import MongoClient, read_preferences, ReadPreference

from datetime import datetime
from itertools import izip_longest

# Log line layout: timestamp, process id, message.
FORMAT = '%(asctime)-15s %(process)d %(message)s'
# NOTE: configures the root logger as an import-time side effect.
logging.basicConfig(format=FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Number of uuids looked up per batch in copy_uuids().
GROUP_BY_NUM = 50
# A progress line is logged every MAX_PROGRESS batches.
MAX_PROGRESS = 100
# Database name used in the connection string and selected on the client.
DB_NAME = 'advisor_db'
# Passed to MongoClient as connectTimeoutMS.
CONNECT_TIMEOUT_MS = 3 * 60 * 1000  # 3 minutes
# Passed to MongoClient as serverSelectionTimeoutMS.
SERVER_SELECTION_TIMEOUT_MS = 5 * 60 * 1000  # 5 minutes


def get_connection(config_entry, environ):
    """ Return a pymongo database handle for the named configuration entry.

    :param config_entry: key of the configuration to use; only
        ``'production'`` is defined in this function.
    :param environ: mapping that provides the ``MONGO_USER`` and
        ``MONGO_REPLICA_PASSWORD`` credentials (normally ``os.environ``).
    :return: the database named by the configuration, obtained from a
        ``MongoClient`` built with the configured read preference and the
        module-level connect / server-selection timeouts.
    :raises KeyError: if ``config_entry`` is not a known configuration.
    :raises ValueError: if a credential is missing from ``environ``.
    """
    # Imported locally so the function stays self-contained and works on
    # both Python 2 (urllib) and Python 3 (urllib.parse).
    try:
        from urllib import quote_plus
    except ImportError:
        from urllib.parse import quote_plus

    MONGO_CONFIG = {
        'production': {
            'user': environ.get('MONGO_USER'),
            'password': environ.get('MONGO_REPLICA_PASSWORD'),
            'hosts': ','.join([
                'iva-g4rokm43mg6rjizc.db.yandex.net:27018',
                'vla-8h8moxom5mcqvhxq.db.yandex.net:27018',
                'sas-350atkewrtqdqznw.db.yandex.net:27018',
            ]),
            'database': DB_NAME,
            'read_preferences': read_preferences.Secondary(),
        },
    }
    try:
        config = MONGO_CONFIG[config_entry]
    except KeyError:
        # Same exception type as before, but with a message that tells the
        # operator which entries actually exist.
        raise KeyError('Unknown mongo config entry %r (known entries: %s)'
                       % (config_entry, ', '.join(sorted(MONGO_CONFIG))))

    # Without this check a missing variable would be formatted into the URI
    # as the literal text "None" and only fail later during authentication.
    missing = [key for key in ('user', 'password') if config[key] is None]
    if missing:
        raise ValueError('Mongo %s is not set for %r config entry'
                         % (' and '.join(missing), config_entry))

    # Credentials must be percent-escaped: characters such as '@', ':', '/'
    # or '%' would otherwise corrupt the URI or be rejected by pymongo.
    conn_str = 'mongodb://{user}:{password}@{hosts}/{database}?maxpoolsize=16'.format(
        user=quote_plus(config['user']),
        password=quote_plus(config['password']),
        hosts=config['hosts'],
        database=config['database'],
    )
    client = MongoClient(conn_str,
                         read_preference=config['read_preferences'],
                         connectTimeoutMS=CONNECT_TIMEOUT_MS,
                         serverSelectionTimeoutMS=SERVER_SELECTION_TIMEOUT_MS,
                         )
    return client[config['database']]


def get_collections(environ):
    """ Open the source and destination databases and return their collections.

    The result maps ``'src'`` / ``'dst'`` to a dict holding the ``'profile'``
    and ``'client'`` collections of the respective database.

    NOTE(review): the destination uses the ``'stress'`` config entry, which
    get_connection() must know about - confirm it is configured.
    """
    databases = (
        ('src', get_connection('production', environ)),
        ('dst', get_connection('stress', environ)),
    )
    collection_names = ('profile', 'client')
    return {
        role: {name: database[name] for name in collection_names}
        for role, database in databases
    }


# Accepts a line made of exactly 32 digit / a-f characters (case-insensitive),
# i.e. a uuid written without dashes; used to filter the lines of the uuids file.
RE_UUID_PATTERN = re.compile(r'^[\da-f]{32}$', re.IGNORECASE)


def grouper(iterable, n):
    """ Split an iterable into consecutive tuples of length n.

    The last tuple is padded with None when the input runs out, e.g.
    grouper('ABCDEFG', 3) --> ABC DEF GNoneNone
    """
    # The same iterator object repeated n times: every tuple slot pulls the
    # next item from the one shared stream.
    shared = iter(iterable)
    slots = [shared for _ in range(n)]
    return izip_longest(*slots, fillvalue=None)


def save_documents(data, collections):
    """ Write the fetched documents into every destination collection.

    :param data: dict mapping a collection name to the list of documents
        to store in it.
    :param collections: dict whose ``'dst'`` entry maps collection names to
        the destination collection objects.

    Existing documents with the same ``_id`` are deleted before the insert;
    a collection with no documents only produces a warning.
    """
    for name, target in collections['dst'].items():
        docs = data[name]
        if not docs:
            logger.warning('No documents have been found in %r collection!', name)
            continue
        doc_ids = [doc['_id'] for doc in docs]
        target.delete_many({'_id': {'$in': doc_ids}})
        target.insert_many(docs, ordered=False)


def save_current_state(state_file_name, last_processed_index):
    """ Overwrite the state file with the index of the last processed uuid. """
    with open(state_file_name, 'wt') as state_file:
        state_text = str(last_processed_index)
        state_file.write(state_text)


def _read_uuids(uuids_file_name):
    """ Return the lines of the file that look like dashless uuids, in file order. """
    with open(uuids_file_name) as uuids_file:
        lines = uuids_file.read().split('\n')
    return [line for line in lines if RE_UUID_PATTERN.match(line)]


def _read_resume_offset(state_file_name):
    """ Return how many uuids a previous run has already processed.

    The state file stores the 0-based index of the last processed uuid, so
    the offset to resume from is that index + 1. A missing or unreadable
    state file means "start from the beginning" (offset 0).
    """
    logger.info('Reading uuids state file: %r', state_file_name)
    try:
        with open(state_file_name) as state_file:
            return max(0, int(state_file.read()) + 1)
    except IOError:
        logger.info('%r file is not exists. ok.', state_file_name)
    except ValueError:
        # E.g. the process died between truncating and writing the file.
        # Copying replaces documents by _id, so starting over is safe.
        logger.warning('%r file is corrupted, starting from the beginning.', state_file_name)
    return 0


def copy_uuids(uuids_file_name, environ, read_delay=0):
    """ Copy client documents and their device profiles for the listed uuids.

    :param uuids_file_name: text file with one dashless uuid per line; lines
        that do not match RE_UUID_PATTERN are ignored.
    :param environ: mapping with the Mongo credentials, passed on to
        get_collections().
    :param read_delay: pause in milliseconds before every batch read except
        the first one; 0 disables the pause.
    :raises IOError: if the uuids file cannot be read.

    Progress is stored in ``<uuids_file_name>.state`` after every batch as
    the index (within the list of valid uuids) of the last processed uuid;
    the state file is removed once all uuids are processed. Any exception
    raised while copying is logged and re-raised.
    """
    # `time` is otherwise only imported under the __main__ guard, which makes
    # this function fail with NameError when the module is imported.
    import time

    logger.info('Reading uuids file: %s', uuids_file_name)
    uuids = _read_uuids(uuids_file_name)
    total = len(uuids)

    state_file_name = uuids_file_name + '.state'
    offset = _read_resume_offset(state_file_name)
    if offset:
        if offset < total:
            logger.info('Continue copying from line %d (%s).', offset - 1, uuids[offset])
        else:
            logger.info('Nothing left to copy: %d of %d uuids are already processed.',
                        min(offset, total), total)
    pending = uuids[offset:]

    started_at = datetime.now()
    i = 0  # stays 0 when there is nothing to process
    try:
        collections = get_collections(environ=environ)
        client_coll = collections['src']['client']
        profile_coll = collections['src']['profile']
        for i, uuid_group in enumerate(grouper(pending, GROUP_BY_NUM), 1):
            # The last group is padded with None by grouper().
            uuid_clients = [UUID(value) for value in uuid_group if value is not None]
            if i > 1 and read_delay:
                time.sleep(read_delay / 1000.)
            client_docs = list(client_coll.find({'_id': {'$in': uuid_clients}}))
            if client_docs:
                uuid_devices = [doc['device_id'] for doc in client_docs if 'device_id' in doc]
                wrong_clients = [doc for doc in client_docs if 'device_id' not in doc]
                if wrong_clients:
                    logger.error('client profile(s) w/o [device_id] value:\n%s', pformat(wrong_clients))

                profile_docs = list(profile_coll.find({'_id': {'$in': uuid_devices}}))
                data = {'profile': profile_docs, 'client': client_docs}
                save_documents(data, collections=collections)
            else:
                logger.warning('No client profiles have been found!')

            # Absolute position in the full uuid list. Computing it from the
            # batch counter keeps the state correct across repeated resumes,
            # and avoids a linear list.index() search per batch that also
            # failed for upper-case or duplicated uuids.
            processed = offset + min(i * GROUP_BY_NUM, len(pending))
            save_current_state(state_file_name, processed - 1)

            if (i % MAX_PROGRESS) == 0:
                elapsed = datetime.now() - started_at
                logger.info('... {:7,} of {:,} elapsed: {}'.format(processed, total, elapsed))

        # No state file exists when nothing was processed in a fresh run.
        if os.path.exists(state_file_name):
            os.remove(state_file_name)

    except Exception:
        logger.exception('Something went wrong:')
        raise

    # Final summary, unless the last batch has just logged one.
    if (i % MAX_PROGRESS) != 0:
        elapsed = datetime.now() - started_at
        logger.info('... {:7,} of {:,} elapsed: {}'.format(total, total, elapsed))


if __name__ == '__main__':
    import argparse
    import sys
    import time

    # Command line: --file is mandatory, --read_delay is an optional pause in ms.
    arg_parser = argparse.ArgumentParser(
        description='Copy device documents from source collection to destination')
    arg_parser.add_argument('--file', required=True,
                            help='name of file contained unique uuids')
    arg_parser.add_argument('--read_delay', required=False, type=int, default=0,
                            help='add delay between read operations in ms')
    options = arg_parser.parse_args()

    logger.info('Started')

    # Up to four attempts; each value is the pause (seconds) before the next
    # attempt, None marks the last attempt after which the script gives up.
    retry_pauses = [10, 30, 60, None]
    for pause in retry_pauses:
        try:
            copy_uuids(options.file, os.environ, options.read_delay)
        except IOError:
            # File-level problems are not retried.
            logger.exception('Could not copy profiles.')
            sys.exit(-1)
        except Exception as error:
            logger.exception('Something goes wrong.\n' + str(error))
            if not pause:
                logger.error('Failed')
                sys.exit(-2)
            logger.info('sleeping %d seconds', pause)
            time.sleep(pause)
        else:
            break

    logger.info('Done')
