#!/usr/bin/env python2.7
import os
import argparse
import fcntl
import random
import time
import json
import logging

# Seconds passed to signal.alarm() in the __main__ block. No SIGALRM handler
# is installed in this file, so the default action (terminate) applies:
# this caps the total execution time of one invocation.
SIGALARM_SEC = 3600

# Seconds a `disable` stays in effect, counted from the mtime of
# FILE_LOCK_RUN (see is_allowed_run); after that the flag is ignored.
TIMEOUT_DISABLE = 7200

# Lock file flock()-ed by `run` so only one instance applies states at a time.
F_LOCK_FR = '/var/run/salt/lock_file'
# Disable-flag file: written by `disable`, removed by `enable`.
FILE_LOCK_RUN = '/var/run/salt/disabled'
# Every `run` appends the raw salt result here as one JSON document per line.
BATCH_LOG = '/var/log/salt-batch.log'
# Log file used when the --log-file option is given.
LOG_FILE_DEFAULT = '/var/log/ya-salt.log'

# Lazily built salt Caller and minion opts, cached by _get_salt_caller().
_GLOBAL_SALT_CALLER = None
_GLOBAL_SALT_OPTS = None

# Script-wide logger; handlers are attached by init_logger().
logger = logging.getLogger('ya-salt')


def init_logger(level, log_file_path=None):
    """Attach handlers to the module ``logger``.

    A file handler is added first when ``log_file_path`` is given, then a
    console (``logging.StreamHandler``) handler. Both use the same format.

    :param level: logging level applied to the logger and console handler.
    :param log_file_path: optional path of a log file to append to.
    """
    global logger
    fmt = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.setLevel(level)

    if log_file_path:
        file_handler = logging.FileHandler(log_file_path)
        file_handler.setFormatter(fmt)
        logger.addHandler(file_handler)

    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(fmt)
    logger.addHandler(console)

    # Announce the file only once both handlers are in place.
    if log_file_path:
        logger.info('Logging to file {} enabled'.format(log_file_path))


def _get_salt_caller(singletone=True):
    """Return a ``salt.client.Caller`` for the local minion.

    The caller and the minion opts used to build it are stored in the
    module-level ``_GLOBAL_SALT_CALLER`` / ``_GLOBAL_SALT_OPTS``.

    :param singletone: if True (default) and a caller was already built,
        return the cached one; otherwise build a new caller from
        ``/etc/salt/minion`` and replace both cached values.
    :return: the (possibly newly built) caller.
    """
    global _GLOBAL_SALT_OPTS, _GLOBAL_SALT_CALLER
    if singletone and _GLOBAL_SALT_CALLER is not None:
        logger.debug('Return caller from cache.')
        return _GLOBAL_SALT_CALLER

    # Import the submodules that are actually used below. Previously only
    # `salt.state` was imported, so salt.config / salt.client were reachable
    # only as a side effect of salt's own internal imports.
    import salt.client
    import salt.config
    import salt.state  # noqa: F401 -- kept from the original import
    logger.debug('Generate new caller and opts (for salt) ...')
    _GLOBAL_SALT_OPTS = salt.config.minion_config('/etc/salt/minion')
    _GLOBAL_SALT_CALLER = salt.client.Caller(mopts=_GLOBAL_SALT_OPTS)
    return _GLOBAL_SALT_CALLER


def _get_salt_opts():
    """Return the cached minion opts, building a salt caller first if needed."""
    if _GLOBAL_SALT_OPTS is not None:
        return _GLOBAL_SALT_OPTS
    # Building a caller fills _GLOBAL_SALT_OPTS as a side effect.
    _get_salt_caller(singletone=False)
    return _GLOBAL_SALT_OPTS


def salt_run(cmd):
    """Run a call through the cached salt caller.

    :param cmd: sequence unpacked positionally into ``Caller.cmd`` --
        presumably the function name followed by its arguments; confirm
        against callers.
    :return: whatever ``Caller.cmd`` returns.
    """
    return _get_salt_caller().cmd(*cmd)


def salt_run_highstate(highstates=None):
    """Render and apply the highstate (or an explicit SLS list) in-process.

    :param highstates: optional list of SLS names to apply instead of the
        states selected by the top file. ``None`` or an empty list means
        "use the top file".
    :return: the result of ``hs.state.call_high`` (a dict keyed by state on
        success), or the list of render errors when rendering failed. In the
        latter case no state is applied.
    :raises Exception: when ``render_highstate`` raises.
    """
    import salt.state
    opts = _get_salt_opts()
    hs = salt.state.HighState(opts, pillar=salt_pillar_items(),
                              pillar_enc=None)

    hs.push_active()
    try:
        matches = hs.top_matches(hs.get_top())
        if highstates:
            # An explicit SLS list replaces the top-file selection. Fall back
            # to 'base' when no environment is configured (as salt's own
            # state.sls does) instead of an environment named None / 'None'.
            matches = {opts.get('environment') or 'base': list(highstates)}
        logger.debug('Matches for render templates : {}'.format(matches))
        logger.debug('Starting compile_highstate ... ')
        st_time_compile = time.time()
        try:
            high_data, errors = hs.render_highstate(matches)
        except Exception as e:
            # Python 2 has no exception chaining: keep the original
            # traceback in the debug log before re-raising.
            logger.debug('render_highstate failed', exc_info=True)
            raise Exception('Raise exeption compile highstate {}'.format(e))
        logger.debug('Highstate compile completed {0:.2f} sec executed.'.format(
            time.time() - st_time_compile
        ))
    finally:
        # Always balance push_active(), even when rendering fails.
        hs.pop_active()

    if errors:
        logger.error('Highstate compile is error')
        for mess in errors:
            logger.error(mess)
        # Do not apply a partially rendered highstate. Returning the error
        # list is what summary_stats_result() reports as fatal.
        return errors

    st_time_call_hs = time.time()
    logger.debug('Starting call highstates ... ')
    hs.load_dynamic(matches)
    result = hs.state.call_high(high_data)
    logger.debug('Called highstates complited, executed time {0:.2f} sec.'.format(
        time.time() - st_time_call_hs
    ))
    return result


def salt_pillar_items():
    """Return the full pillar dict from the minion opts (KeyError if absent)."""
    opts = _get_salt_opts()
    return opts['pillar']


def salt_pillar_get(pillar_name, default=None):
    """Return a top-level pillar value, or *default* when it is not set.

    Only top-level keys are looked up; ``pillar_name`` is used as a plain
    dict key (it is not split on ':').

    :param pillar_name: top-level pillar key.
    :param default: value returned when the key (or the pillar) is missing.
    """
    pillar = _get_salt_opts().get('pillar', {})
    return pillar.get(pillar_name, default)


def salt_grains_items():
    """Return the grains dict from the minion opts, or ``{}`` if absent."""
    opts = _get_salt_opts()
    return opts.get('grains', {})


def salt_run_all_state():
    """Apply every state selected by the top file and return the result."""
    # No explicit SLS list: the top file decides what is applied.
    result = salt_run_highstate()
    return result


def salt_run_custom_state(state):
    """Apply only the SLS named *state* and return the result."""
    selected = [state]
    return salt_run_highstate(highstates=selected)


def lock_run(path):
    """Try to take an exclusive, non-blocking ``flock`` on *path*.

    The lock file is created if needed. On success the descriptor is
    deliberately left open (and not returned), so the lock is held until the
    process exits.

    :param path: lock file path.
    :return: True if the lock was acquired; False if the lock is already
        held elsewhere or the file could not be opened/created.
    """
    fd = None
    try:
        fd = os.open(path, os.O_RDONLY | os.O_CREAT, 0o644)
        fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        return True
    except (OSError, IOError):
        # `is not None` rather than truthiness: descriptor 0 is valid.
        if fd is not None:
            os.close(fd)
        return False


def is_allowed_run():
    """Return True unless an unexpired disable flag file is present.

    The flag at ``FILE_LOCK_RUN`` blocks runs for ``TIMEOUT_DISABLE``
    seconds counted from its modification time; after that it is ignored.
    """
    try:
        disabled_at = os.path.getmtime(FILE_LOCK_RUN)
    except OSError:
        # No flag file, or it vanished between checks (e.g. a concurrent
        # `enable`): runs are allowed.
        return True
    return disabled_at + TIMEOUT_DISABLE <= time.time()


def cmd_disable(args):
    """Disable automatic runs by (re)writing the disable flag file.

    Calling it again while already disabled refreshes the flag, which
    restarts the ``TIMEOUT_DISABLE`` window.

    :param args: parsed CLI namespace (unused).
    :return: always True.
    """
    if is_allowed_run():
        logger.info("Salt disabled.")
    else:
        logger.warning("Salt already disabled, update lock time.")

    # is_allowed_run() only looks at the file's mtime; nothing in this
    # script reads the timestamp written inside.
    with open(FILE_LOCK_RUN, "w") as flag:
        flag.write(str(time.time()))
    return True


def summary_stats_result(result):
    """Count per-state outcomes in a highstate result.

    :param result: either a dict mapping state ids to per-state dicts (which
        may carry ``result`` and ``changes`` keys), or a list, which is
        treated as a fatal error report (salt returns a list of error
        strings when states could not be compiled/run).
    :return: dict with integer counters ``ok``, ``fail``, ``changed`` and
        ``fatal``. ``fatal`` is 1 for a list result; any other non-dict input
        yields all zeros.
    """
    stat = {'ok': 0, 'fail': 0, 'changed': 0, 'fatal': 0}

    if isinstance(result, dict):
        # .values() works on Python 2 and 3 (iteritems is Python 2 only).
        for data in result.values():
            if 'result' in data:
                # Any falsy result (False, None) counts as a failure.
                stat['ok' if data['result'] else 'fail'] += 1
            # Truthiness also tolerates `changes: None`.
            if data.get('changes'):
                stat['changed'] += 1
    elif isinstance(result, list):
        stat['fatal'] = 1
    return stat


def cmd_enable(args):
    """Re-enable automatic runs by removing the disable flag file.

    :param args: parsed CLI namespace (unused).
    :return: True after removing the flag; False (with a warning) when runs
        are already allowed.
    """
    if not is_allowed_run():
        os.unlink(FILE_LOCK_RUN)
        logger.info("Salt enabled.")
        return True
    logger.warning("Salt already enabled.")
    return False


def cmd_status(args):
    """Report whether automatic runs are enabled.

    :param args: parsed CLI namespace; only ``status_disable`` is read.
    :return: with ``--status-disable``, True when enabled and False when
        disabled (the state is also logged); without it, False.
    """
    if not args.status_disable:
        return False
    enabled = is_allowed_run()
    logger.info("Status salt : Enabled" if enabled else "Status salt : Disabled")
    return enabled


def cmd_run(args):
    """Apply states once and append the raw result to ``BATCH_LOG``.

    Steps: take the run-once lock, honour the disable flag unless
    ``--ignore-run-once`` is set, optionally sleep a random delay, apply
    all states (or ``--state``), log a summary, and append the JSON result.

    :param args: parsed CLI namespace (``ignore_run_once``, ``random_sleep``,
        ``state``).
    :return: True on success; False when locked, disabled, when the run
        returned False, or when the result is fatal (a list of errors).
    """
    # lock_run() keeps the descriptor open, so the lock is held until exit.
    if not lock_run(F_LOCK_FR):
        logger.fatal("{} is locked.".format(F_LOCK_FR))
        return False

    if not args.ignore_run_once and not is_allowed_run():
        logger.fatal("Salt disabled ({}), if u want ignore use --ignore-run-once flag".format(FILE_LOCK_RUN))
        return False

    # Skip non-positive values: randint(1, n) raises ValueError for n < 1.
    if args.random_sleep and args.random_sleep > 0:
        logger.info("Random sleeping {} ... ".format(args.random_sleep))
        time.sleep(random.randint(1, args.random_sleep))

    if args.state is None:
        # Equivalent of `salt-call state.apply`.
        logger.info("Execute all state ...")
        res_exe = salt_run_all_state()
    else:
        # Equivalent of `salt-call state.apply YOUR_STATE`.
        logger.info('Execute custom state "{}" ...'.format(args.state))
        res_exe = salt_run_custom_state(args.state)

    # Checked for both branches, not only for the all-states run.
    if res_exe is False:
        return False

    sm_stat = summary_stats_result(res_exe)
    logger.info('Summary: ok={ok} fail={fail} changed={changed} fatal={fatal}'.format(**sm_stat))

    with open(BATCH_LOG, 'a') as log:
        log.write(json.dumps(res_exe) + "\n")

    if sm_stat['fatal'] > 0:
        logger.fatal("Fatal error. Please see {}".format(BATCH_LOG))
        return False
    return True


def _build_parser():
    """Build the ya-salt argument parser (one sub-command per cmd_* handler)."""
    parser = argparse.ArgumentParser()
    # dest + required: without them Python 3's argparse accepts a missing
    # sub-command and main() would crash on the absent `handle` attribute.
    subparsers = parser.add_subparsers(dest='command', help='List command for ya-salt')
    subparsers.required = True

    # Global params
    parser.add_argument('--debug', action="store_true", help='Debug logging', default=False)
    parser.add_argument('--log-file', action="store_true", help='Write log to {}'.format(
        LOG_FILE_DEFAULT), default=False)

    disable_parser = subparsers.add_parser('disable', help='Disable auto salt-call {} sec'.format(TIMEOUT_DISABLE))
    enable_parser = subparsers.add_parser('enable', help='Enable auto salt-call')
    status_parser = subparsers.add_parser('status', help='Status')
    run_parser = subparsers.add_parser('run', help='Run salt-call')

    run_parser.add_argument('--state', help='Name state', default=None)
    run_parser.add_argument('--random-sleep', type=int, help='Random sleep before start', default=None)
    run_parser.add_argument('--ignore-run-once', action="store_true", help='For ignore disabled flag', default=False)

    status_parser.add_argument('--status-disable', action="store_true", help='Return status disable', default=False)

    disable_parser.set_defaults(handle=cmd_disable)
    enable_parser.set_defaults(handle=cmd_enable)
    status_parser.set_defaults(handle=cmd_status)
    run_parser.set_defaults(handle=cmd_run)
    return parser


def main():
    """Parse the command line, configure logging and dispatch the sub-command.

    :return: the selected handler's return value (truthy means success).
    """
    args = _build_parser().parse_args()

    log_level = logging.DEBUG if args.debug else logging.INFO
    log_file_path = LOG_FILE_DEFAULT if args.log_file else None
    init_logger(log_level, log_file_path)

    return args.handle(args)


if __name__ == "__main__":
    import sys
    import signal

    # Hard cap on runtime: no SIGALRM handler is installed, so the default
    # action terminates the process after SIGALARM_SEC seconds.
    signal.alarm(SIGALARM_SEC)
    if main():
        logger.debug('Success')
        logger.info('-'*40)
        sys.exit(0)
    logger.error('Fatal')
    logger.info('-'*40)
    # Report failure through the exit status (cron, monitoring, and
    # `status --status-disable` rely on it); previously this fell through
    # and exited 0.
    sys.exit(1)
