# encoding: utf-8
from __future__ import unicode_literals

import errno
import json
import logging
import os
import sys
import time
from datetime import datetime
from functools import partial
from multiprocessing import Pool

import boto3
import requests
from tornado.ioloop import IOLoop
from frozendict import frozendict

from intranet.webauth.lib.role_cache import (
    write_role_cache,
    get_cache_version,
    remove_version,
    set_cache_version,
    write_version_timestamp,
)
from intranet.webauth.lib.settings import (
    WEBAUTH_IDM_CACHE_FOLDER,
    WEBAUTH_OAUTH_TOKEN,
    WEBAUTH_IDM_HOST,
    WEBAUTH_IDM_PAGE_SIZE,
    WEBAUTH_USE_MDS,
    YENV_TYPE,
    WEBAUTH_ACCESS_KEY_ID,
    WEBAUTH_SECRET_ACCESS_KEY,
)
from intranet.webauth.lib.utils import (
    setup_logging,
    requests_with_retry_policy,
    lock_file,
)

# Logging is configured as an import-time side effect, before the module logger is created.
setup_logging()
logger = logging.getLogger('generate_idm_cache')

# CA bundle passed as `verify=` to every HTTPS request made to IDM.
CA_CERTIFICATES_PATH = '/etc/ssl/certs/ca-certificates.crt'
# IDM API paths, appended to "https://<WEBAUTH_IDM_HOST>".
IDM_GETROLES_API = '/api/v1/roles/'
IDM_GET_NODES_API = '/api/v1/rolenodes/'  # NOTE(review): not used in this module; may be imported elsewhere
IDM_GET_SYSTEMS_API = '/api/v1/systems/'
# Headers sent with every IDM request (OAuth token authentication).
IDM_HEADERS = {
    'Content-Type': 'application/json',
    'Authorization': 'OAuth %s' % WEBAUTH_OAUTH_TOKEN
}
# Passed as `timeout=` to IDM requests (seconds, per the requests convention).
IDM_TIMEOUT = 300
# Passed to requests_with_retry_policy(); presumably the number of HTTP attempts — confirm in utils.
IDM_TRIES = 7
OUTPUT_DIR = WEBAUTH_IDM_CACHE_FOLDER  # NOTE(review): not used in this module; may be imported elsewhere
# Number of worker processes used to download role pages; overridable via the environment.
CONCURENT_PROCESS = int(os.environ.get('GENERATE_IDM_CACHE_POOL_SIZE', 4))
# Number of (user, role, fields_data) tuples sent per write_role_cache() call.
REDIS_WRITE_BATCH_SIZE = 10
# Pause passed to time.sleep() after each full redis write batch: 0.0002 s = 200 μs.
REDIS_WRITE_SLEEP_TIME = 0.0002  # 200 μs
# Attempts per redis step in upload_to_redis() before the process exits.
REDIS_RETRIES = 15


def make_dir(dir):
    """Ensure that the directory ``dir`` exists, creating it (mode 0o755) if needed.

    Returns True when ``dir`` is a directory after the call. Returns False when
    it could not be created; the failure is logged as a warning.
    """
    try:
        if not os.path.isdir(dir):
            logger.debug("generate_idm_cache: Make dir %s", dir)
            os.makedirs(dir, mode=0o755)
        return True
    except OSError as e:
        if e.errno == errno.EEXIST:
            # EEXIST is benign when another process created the directory between
            # the isdir() check and makedirs(). It is also raised when a regular
            # *file* occupies the path, which must be reported as a failure.
            if os.path.isdir(dir):
                return True
        logger.warning("generate_idm_cache: Can't create dir %s. Details: %s.", dir, e)

    return False


def get_webauth_idm_systems():
    """Return the slugs of all IDM systems queried with ``use_webauth=true``.

    Follows the paginated systems API through ``meta.next`` until it is None.
    A failed request is logged and the RequestException is re-raised.
    """
    slugs = []
    next_page = "{}?use_webauth=true".format(IDM_GET_SYSTEMS_API)

    while next_page is not None:
        page_url = "https://%s%s" % (WEBAUTH_IDM_HOST, next_page)
        try:
            http = requests_with_retry_policy(IDM_TRIES)
            response = http.get(page_url, headers=IDM_HEADERS, timeout=IDM_TIMEOUT, verify=CA_CERTIFICATES_PATH)
            response.raise_for_status()
        except requests.exceptions.RequestException:
            logger.error("generate_idm_cache: All attemps to get batch from IDM had failed")
            raise

        payload = response.json()
        for system in payload['objects']:
            slugs.append(system.get('slug'))
        next_page = payload['meta'].get('next')

    return slugs


def cache_webauth_idm_systems(webauth_idm_systems):
    """Persist the list of webauth IDM system slugs as JSON.

    Writes ``webauth_idm_systems.json`` in WEBAUTH_IDM_CACHE_FOLDER while holding
    lock_file() on it. If the lock is not acquired, a warning is logged and
    nothing is written.
    """
    output_path = os.path.join(WEBAUTH_IDM_CACHE_FOLDER, 'webauth_idm_systems.json')
    with lock_file(output_path, blocking=True) as lock:
        if lock.locked:
            logger.info('generate_idm_cache: WEBAUTH_IDM_SYSTEMS saving data to %s', output_path)
            # Close (and so flush) the file before the lock is released, instead of
            # relying on garbage collection of an anonymous file object.
            with open(output_path, 'w') as output_file:
                json.dump(webauth_idm_systems, output_file)
        else:
            logger.warning("generate_idm_cache: Could not acquire lock on %s", output_path)


def get_roles_batch(system, limit, offset=0, keep_meta=False, no_meta=False):
    """Fetch one page of active personal roles of ``system`` from IDM.

    :param system: IDM system slug.
    :param limit: page size requested from IDM.
    :param offset: index of the first role of the page.
    :param keep_meta: when true, return the decoded IDM response unchanged.
    :param no_meta: when true, send ``no_meta=True`` in the query.
    :returns: None if the request failed (the failure is logged). Otherwise the
        raw response dict if ``keep_meta`` is set, or a dict mapping
        ``os.path.join(system, value_path)`` to a list of
        ``(username, frozendict(fields_data))`` tuples.
    """
    params = {
        'system': system,
        'limit': limit,
        'type': 'active',
        'ownership': 'personal',
        'offset': offset,
        'for_webauth': True,
    }
    if no_meta:
        params['no_meta'] = True
    endpoint = "https://%s%s" % (WEBAUTH_IDM_HOST, IDM_GETROLES_API)

    try:
        http = requests_with_retry_policy(IDM_TRIES)
        response = http.get(endpoint, params=params, headers=IDM_HEADERS,
                            timeout=IDM_TIMEOUT, verify=CA_CERTIFICATES_PATH)
        response.raise_for_status()
    except requests.exceptions.RequestException:
        logger.error("generate_idm_cache: All attemps to get batch from IDM had failed")
        return None

    payload = json.loads(response.text)
    if keep_meta:
        return payload

    grouped = {}
    for role in payload.get('objects', []):
        # NOTE(review): os.path.join discards ``system`` when value_path starts with
        # '/'. Confirm which form IDM returns, since upload_system() prepends the
        # system slug again when building the redis role path.
        holders = grouped.setdefault(os.path.join(system, role['node']['value_path']), [])
        username = role['user']['username']
        holders.append((username, frozendict(role['fields_data'] or {})))
    return grouped


def frozen_dict_dumps(obj):
    """``default=`` hook for json.dump that serializes frozendict values.

    json cannot serialize frozendict by default, so it is converted to a plain
    dict. Any other unsupported object raises TypeError, as json expects.
    """
    if isinstance(obj, frozendict):
        return dict(obj)
    # Format the message here. Passing the argument separately (logging-style)
    # would produce a TypeError whose message is a tuple, not readable text.
    raise TypeError('Object of type %s is not json serializable' % type(obj).__name__)


def merge_files(system, blocks):
    """Merge per-page role dicts of ``system`` and write the result with write_system().

    ``blocks`` are the dicts produced by get_roles_batch(): node path -> list of
    (user, fields_data) entries. Duplicate entries of a node are removed (their
    order is not preserved).
    """
    merged = {}
    for block in blocks:
        for node_path, entries in block.items():
            merged.setdefault(node_path, []).extend(entries)

    unique = {node_path: list(set(entries)) for node_path, entries in merged.items()}
    write_system(system, unique)


def get_batch_wrapper(args):
    """Pool worker entry: unpack ``(system, limit, offset)`` and fetch that page with ``no_meta=True``."""
    system_slug, page_limit, page_offset = args
    return get_roles_batch(system_slug, limit=page_limit, offset=page_offset, no_meta=True)


def update_system(idm_system):
    """Download all active personal roles of ``idm_system`` and write them to disk.

    The total role count is read first. The pages are then fetched in parallel
    by a pool of CONCURENT_PROCESS worker processes and merged via merge_files().
    If any request fails, an error is logged and nothing is written.
    """
    info_batch = get_roles_batch(idm_system, limit=1, keep_meta=True)
    if info_batch is None:
        # Without the total count there is nothing to page through; the HTTP
        # failure itself was already logged by get_roles_batch().
        logger.error("generate_idm_cache: An error occured while getting roles from IDM for system %s", idm_system)
        return
    total_count = info_batch['meta']['total_count']
    logger.info('generate_idm_cache: System \"%s\": %d roles in total', idm_system, total_count)

    # One (system, limit, offset) task per page; only the last page may be short.
    tasks = [
        (idm_system, min(WEBAUTH_IDM_PAGE_SIZE, total_count - offset), offset)
        for offset in range(0, total_count, WEBAUTH_IDM_PAGE_SIZE)
    ]

    results = []
    pool = Pool(CONCURENT_PROCESS)
    try:
        processed = 0
        # imap (ordered) keeps result #i aligned with tasks[i], so the progress
        # counter adds the size of the page that actually completed.
        for index, result in enumerate(pool.imap(get_batch_wrapper, tasks)):
            processed += tasks[index][1]
            logger.debug("generate_idm_cache: Processed %d out of %d roles", processed, total_count)
            results.append(result)
    finally:
        # All results are consumed (or we are bailing out). Stop the workers, or
        # every call would leak a whole pool of processes.
        pool.terminate()
        pool.join()

    if any(result is None for result in results):
        logger.error("generate_idm_cache: An error occured while getting roles from IDM for system %s", idm_system)
        return
    merge_files(idm_system, results)


def write_system(system, data):
    """Dump ``data`` as JSON to ``<WEBAUTH_IDM_CACHE_FOLDER>/<system>.json`` under lock_file().

    frozendict values are serialized through frozen_dict_dumps(). If the lock
    is not acquired, a warning is logged and the file is left untouched.
    """
    target = os.path.join(WEBAUTH_IDM_CACHE_FOLDER, '{}.json'.format(system))
    with lock_file(target, blocking=True) as lock:
        if not lock.locked:
            logger.warn("generate_idm_cache: Could not acquire lock on %s", target)
            return
        logger.info('generate_idm_cache: System %s: saving data to %s', system, target)
        with open(target, 'w') as out:
            json.dump(data, out, default=frozen_dict_dumps)


def get_data_from_mds():
    """Download and decode the latest webauth roles snapshot from MDS S3.

    Reads the object ``<YENV_TYPE>_webauth_roles_latest`` from the
    ``webauth-roles`` bucket and returns the decoded JSON document.
    """
    s3_client = boto3.session.Session(
        aws_access_key_id=WEBAUTH_ACCESS_KEY_ID,
        aws_secret_access_key=WEBAUTH_SECRET_ACCESS_KEY,
    ).client(service_name='s3', endpoint_url='https://s3.mds.yandex.net')

    snapshot_key = '{}_webauth_roles_latest'.format(YENV_TYPE)
    s3_object = s3_client.get_object(Bucket='webauth-roles', Key=snapshot_key)
    raw_body = s3_object['Body'].read()
    return json.loads(raw_body)


def load_remote_data(webauth_idm_systems):
    """Refresh the on-disk role cache (one JSON file per system).

    When WEBAUTH_USE_MDS is set, the systems and roles come from the MDS
    snapshot. Otherwise each system in ``webauth_idm_systems`` is downloaded
    from the IDM API. Exits the process with status 1 if the cache folder
    cannot be created.
    """
    logger.info('Loading remote data')
    # Raise the HTTP libraries' loggers to WARNING so their DEBUG/INFO records are dropped.
    logging.getLogger("requests").setLevel(logging.WARNING)
    logging.getLogger("urllib3").setLevel(logging.WARNING)
    start_time = time.time()
    if not make_dir(WEBAUTH_IDM_CACHE_FOLDER):
        sys.exit(1)
    if WEBAUTH_USE_MDS:
        logger.info('Loading remote data from mds')
        roles_data = get_data_from_mds()
        logger.info('Loaded remote data from mds')
        # items() works on Python 2 and 3; dict.iteritems() does not exist on Python 3.
        # NOTE(review): in this mode the systems written come from the snapshot,
        # not from ``webauth_idm_systems``.
        for system, data in roles_data.items():
            write_system(system, data)
    else:
        logger.info('Loading remote data from api')
        for idm_system in webauth_idm_systems:
            update_system(idm_system)
        logger.info('Loaded remote data from api')
    duration = int(time.time() - start_time)
    logger.info("generate_idm_cache: IDM cache file was generating for %dm %ds", *divmod(duration, 60))


def upload_system(system, cache_version):
    """Write every role of ``system`` from its cached JSON file into redis under ``cache_version``.

    The file is read while holding lock_file(). Each role path is formatted as
    ``/<system>/<node>/``. fields_data that is not already a dict is converted
    with dict(). Roles are sent to write_role_cache() in batches of
    REDIS_WRITE_BATCH_SIZE, sleeping REDIS_WRITE_SLEEP_TIME seconds after each
    full batch.
    """
    source_path = os.path.join(WEBAUTH_IDM_CACHE_FOLDER, '{}.json'.format(system))
    with lock_file(source_path, blocking=True):
        with open(source_path) as source:
            nodes = json.load(source)

    def flush(pending_roles):
        IOLoop.current().run_sync(partial(write_role_cache, cache_version, pending_roles))

    pending = []
    for node, holders in nodes.items():
        role = '/{}/{}/'.format(system, node.strip('/'))
        for user, fields_data in holders:
            if not isinstance(fields_data, dict):
                fields_data = dict(fields_data)
            pending.append((user, role, fields_data))
            if len(pending) == REDIS_WRITE_BATCH_SIZE:
                flush(pending)
                time.sleep(REDIS_WRITE_SLEEP_TIME)
                pending = []

    if pending:
        flush(pending)


def _retry_or_exit(action, attempt_error, exhausted_error, *log_args):
    """Call ``action()`` up to REDIS_RETRIES times and return its result.

    Each failed attempt is logged with its traceback. If every attempt raises,
    ``exhausted_error`` is logged and the process exits with status 1.
    """
    for _ in range(REDIS_RETRIES):
        try:
            return action()
        except Exception:
            logger.exception(attempt_error, *log_args)
    logger.error(exhausted_error, *log_args)
    sys.exit(1)


def upload_to_redis(non_existing_only, webauth_idm_systems):
    """Upload the cached role files to redis as a new cache version.

    Runs under a file lock (``redis.lock`` in WEBAUTH_IDM_CACHE_FOLDER):
    1. reads the current version;
    2. uploads every system under version + 1, or 0 when no version exists yet;
    3. sets the new version, writes its timestamp and removes version - 2.

    :param non_existing_only: when true, return without uploading if redis already
        reports a cache version.
    :param webauth_idm_systems: slugs of the systems whose JSON files are uploaded.
    Exits the process with status 1 when any redis step keeps failing.
    """
    start_time = time.time()
    lock_path = os.path.join(WEBAUTH_IDM_CACHE_FOLDER, 'redis.lock')
    with lock_file(lock_path, blocking=True):
        cache_version = _retry_or_exit(
            lambda: IOLoop.current().run_sync(get_cache_version),
            'Could not get version number from redis',
            'Retries for get_cache_version ended',
        )

        if non_existing_only and cache_version is not None:
            return
        if non_existing_only:
            logger.info('generate_idm_cache: loading non-existing data to redis')
        else:
            logger.info('generate_idm_cache: loading data to redis')
        # Test against None (as above) rather than truthiness, so an existing
        # version 0 is advanced to 1 instead of being overwritten in place.
        cache_version = (int(cache_version) + 1) if cache_version is not None else 0

        for system in webauth_idm_systems:
            _retry_or_exit(
                partial(upload_system, system, cache_version),
                'generate_idm_cache: could not load system %s to redis',
                'Retries for upload_system ended - %s',
                system,
            )

        def publish_version():
            IOLoop.current().run_sync(partial(set_cache_version, cache_version))
            IOLoop.current().run_sync(partial(write_version_timestamp, cache_version))
            # Remove the version two generations back; version N-1 is not touched here.
            if cache_version - 2 >= 0:
                IOLoop.current().run_sync(partial(remove_version, cache_version - 2))

        _retry_or_exit(
            publish_version,
            'generate_idm_cache: could not update version metadata in redis',
            'Retries for setting cache version ended',
        )

    duration = time.time() - start_time
    logger.info('generate_idm_cache: Writing data to redis took %s seconds', round(duration, 3))


def check_cache_existence(webauth_idm_systems):
    """Return True when every named system has a non-empty ``<slug>.json`` in the cache folder.

    Empty/None slugs are ignored. Returns False if the cache folder itself is
    missing.
    """
    if not os.path.exists(WEBAUTH_IDM_CACHE_FOLDER):
        return False
    for slug in filter(None, webauth_idm_systems):
        cache_path = os.path.join(WEBAUTH_IDM_CACHE_FOLDER, "%s.json" % slug)
        if not os.path.exists(cache_path) or os.path.getsize(cache_path) == 0:
            return False
    return True


def main():
    """Entry point: refresh the IDM role cache on disk and upload it to redis.

    The first command-line argument is honoured only when a complete on-disk
    cache already exists:
      --try-preloaded-data  reuse the cached files instead of downloading;
      --poke-empty-redis    reuse the cached files and upload only if redis
                            has no cache version yet.
    After a full download, the begin/end UTC timestamps are written to
    last_sync.json, and the previous last_sync.json is moved to
    previous_sync.json.
    """
    logger.info('Starting generate_idm_cache process')
    webauth_idm_systems = get_webauth_idm_systems()
    cache_webauth_idm_systems(webauth_idm_systems)

    mode = None
    if check_cache_existence(webauth_idm_systems) and len(sys.argv) >= 2:
        mode = sys.argv[1]
    download_remote_data = mode not in ('--try-preloaded-data', '--poke-empty-redis')
    non_existing_only = mode == '--poke-empty-redis'
    begin = datetime.utcnow()

    if download_remote_data:
        load_remote_data(webauth_idm_systems)
    logger.info('Uploading roles to redis')
    upload_to_redis(non_existing_only, webauth_idm_systems)

    # Sync timestamps are recorded only for full, non-cached syncs.
    if download_remote_data:
        logger.info('Updating sync data files')
        end = datetime.utcnow()
        previous_sync_path = os.path.join(WEBAUTH_IDM_CACHE_FOLDER, 'previous_sync.json')
        last_sync_path = os.path.join(WEBAUTH_IDM_CACHE_FOLDER, 'last_sync.json')
        if os.path.exists(last_sync_path):
            os.renames(last_sync_path, previous_sync_path)
        with open(last_sync_path, 'w') as sync_file:
            json.dump({'begin': str(begin), 'end': str(end)}, sync_file)
    logger.info('Finished generate_idm_cache process')
