#!/usr/bin/env python

import argparse
import json
import logging
import logging.handlers
import os
import signal
import socket
import struct
import sys
import time

from datetime import datetime
from pprint import pformat
from tempfile import NamedTemporaryFile

from agent import Component, util
from modules.abstract import PlainModule


def snapshot_version():
    """Return the placeholder revision string ``'snapshot'``.

    Used as the revision provider when no real revision lookup is available.
    """
    placeholder = 'snapshot'
    return placeholder


# Use the revision reported by the optional ``svnversion`` module; when that
# module cannot be imported, fall back to the local placeholder provider.
try:
    from svnversion import version as get_svn_version
except ImportError:
    get_svn_version = snapshot_version

# Module-wide logger.
LOG = logging.getLogger(__name__)

# Version string laid out as major.minor.revision.build.
AGENT_VERSION = '{0}.{1}.{2}.{3}'.format(3, 0, get_svn_version(), 0)


def _deep_update_dict(target_dict, source_dict):
    for k, v in source_dict.iteritems():
        if isinstance(v, dict):
            _deep_update_dict(target_dict.setdefault(k, {}), v)
        else:
            target_dict[k] = source_dict[k]


def read_configuration_file(filepath):
    """Load the JSON configuration stored in ``filepath``.

    :param filepath: path to a JSON file whose top-level value is an object
    :return: the parsed configuration as a dict
    :raises AssertionError: if the top-level JSON value is not an object
    :raises IOError, ValueError: propagated from ``open`` / ``json.load``
    """
    with open(filepath, 'r') as f:
        # UTF-8 is already the decoder's default encoding, so it is not
        # passed explicitly.
        config = json.load(f)

    # todo: add better configuration file validation
    # Raised explicitly (rather than via ``assert``) so the check is still
    # performed when Python runs with -O; the exception type is unchanged.
    if not isinstance(config, dict):
        raise AssertionError('invalid configuration file format')

    return config


def _recv_exactly(sock, count, error_message):
    """Read exactly ``count`` bytes from ``sock``.

    ``recv`` may return fewer bytes than requested, so keep reading until
    the requested amount has arrived.

    :raises RuntimeError: with ``error_message`` if the peer closes the
        connection before ``count`` bytes were received
    """
    chunks = []
    while count:
        chunk = sock.recv(count)
        if not chunk:
            raise RuntimeError(error_message)

        chunks.append(chunk)
        count -= len(chunk)

    return b''.join(chunks)


def dump_states(config_path):
    """Ask the running agent for its state dump and print it to stdout.

    The address is taken from
    ``components.socketreporter.config.rpc_path`` in the configuration file:
    an integer is used as a port on ``::1`` (AF_INET6), anything else as the
    address of an AF_UNIX socket. The request is the single byte ``0x02``;
    the reply is read as a 4-byte big-endian unsigned length followed by
    that many bytes of payload.

    :param config_path: path to the agent configuration file
    :raises RuntimeError: if the reporter is not configured, the connection
        fails, or the peer closes the connection before the full reply
        arrived
    """
    cfg = read_configuration_file(config_path)

    try:
        socket_path = cfg['components']['socketreporter']['config']['rpc_path']
    except (KeyError, TypeError):
        raise RuntimeError("state dumper is not configured for this host")

    use_port = isinstance(socket_path, int)

    try:
        sock = socket.socket(socket.AF_INET6 if use_port else socket.AF_UNIX)
        try:
            sock.settimeout(120)
            sock.connect(('::1', socket_path) if use_port else socket_path)
            sock.sendall(b'\x02')

            header = _recv_exactly(sock, 4, "unexpected connection error")
            payload_len, = struct.unpack(">I", header)
            data = _recv_exactly(sock, payload_len, "server unexpectedly closed connection")
        finally:
            # Release the socket on every path, including protocol errors.
            sock.close()
    except socket.error as e:
        raise RuntimeError("cannot connect to state dumper: %s" % (e,))

    # The payload is printed exactly as received; it is not validated here.
    print(data)


class Agent(object):
    """Agent core: loads, starts, reloads and stops components and modules.

    The shared ``context`` dict carries the agent itself (``'core'``), the
    loaded components and modules, and - once ``prepare_config`` has run -
    the parsed configuration (``'config'``) and the host names
    (``'hostnames'``).

    Signal handling installed by ``start``: SIGUSR1 reloads, SIGUSR2 writes
    a state dump to a temporary file, SIGTERM stops the agent.
    """

    # Minimum number of seconds between two parent-pid checks.
    PARENT_PID_CHECK_INTERVAL_SECS = 5

    def __init__(self, args):
        """Remember the parsed command line ``args`` and the current parent pid."""
        self.args = args
        self.state = {'version': AGENT_VERSION, 'start_time': datetime.utcnow()}
        self.context_config = {}
        self.running = None
        self.context = {
            'core': self,
            'components': {},
            'modules': {}
        }
        self.ppid = os.getppid()
        LOG.info("ppid=%d", self.ppid)
        self.next_ppid_check_time = time.time() + self.PARENT_PID_CHECK_INTERVAL_SECS

    def __repr__(self):
        return "oops-agent version: %s" % self.state['version']

    def check_ppid_changes(self):
        """Stop the agent when its parent process id has changed.

        The check is rate limited to once per
        ``PARENT_PID_CHECK_INTERVAL_SECS``; earlier calls return immediately.
        """
        if time.time() < self.next_ppid_check_time:
            return

        current_ppid = os.getppid()
        if current_ppid != self.ppid:
            LOG.error("ppid changed: %d -> %d, quit agent", self.ppid, current_ppid)
            # Go through stop() rather than just clearing ``running``:
            # stop() only shuts modules and components down while
            # ``running`` is still true, so clearing the flag here would
            # make every later stop() call skip the shutdown.
            self.stop()
            return

        self.next_ppid_check_time = time.time() + self.PARENT_PID_CHECK_INTERVAL_SECS

    def start(self):
        """Install the signal handlers and start components and modules.

        When ``args.pidfile`` is set, startup happens while holding
        ``util.PidLockFile``; in daemon mode the process is daemonized
        first and the pid is then written to the pid file.

        :raises Exception: if the pid file reports an already running agent
        """
        self.running = True

        LOG.info('register signal handlers')
        signal.signal(signal.SIGUSR1, self.reload)
        signal.signal(signal.SIGUSR2, self.save_state)
        signal.signal(signal.SIGTERM, self.sigterm)

        pidfile_path = self.args.pidfile
        if pidfile_path:
            pidfile = util.PidLockFile(pidfile_path)
            running_pid = pidfile.is_running()
            if running_pid:
                LOG.error('agent is already running, pid %i!', running_pid)
                print('agent is running!')
                raise Exception('agent is already running, pid %i!' % running_pid)

            with pidfile:
                if self.args.daemonize:
                    util.daemonize()
                    pidfile.write_pid()

                self._startup()
        else:
            self._startup()

    def _startup(self):
        """Read the configuration, load and init everything, then start it."""
        LOG.info('init')
        self.prepare_config()
        self.load_components_and_modules()

        LOG.info('start components')
        for name, component in self.context['components'].items():
            LOG.info('starting %s', name)
            if isinstance(component, Component):
                component.start()

        LOG.info('start modules')
        for name, module in self.context['modules'].items():
            LOG.info('starting %s', name)
            module.start()

    def _shutdown(self):
        """Stop all modules, then stop and join all components.

        Shutdown is best effort: a failure in one module or component is
        logged and does not prevent the remaining ones from being stopped.
        Stopped entries are removed from the context.
        """
        modules = self.context['modules']
        # Iterate over a snapshot because entries are deleted in the loop.
        for modname, module in list(modules.items()):
            LOG.info('stopping module %s', modname)
            try:
                module.stop()
            except:  # deliberately broad: keep shutting the rest down
                LOG.exception('error stopping module %s', modname)
            finally:
                del modules[modname]

        components = self.context['components']
        # Ask every component to stop first, then wait for them one by one.
        for name, component in list(components.items()):
            LOG.info('stopping component %s', name)
            try:
                component.stop()
            except:  # deliberately broad: keep shutting the rest down
                LOG.exception('error stopping component %s', name)

        for name, component in list(components.items()):
            LOG.info('joining component %s', name)
            try:
                alive = component.is_alive
                if callable(alive):
                    # NOTE(review): if components are threading.Thread
                    # based, ``is_alive`` is a method and must be called -
                    # the bare attribute would always be truthy. A plain
                    # attribute/property is used as is.
                    alive = alive()
                if alive:
                    component.join()
            except:  # deliberately broad: keep shutting the rest down
                LOG.exception('error joining component %s', name)
            finally:
                del components[name]

    def prepare_config(self):
        """(Re)read the configuration file into the shared context.

        Also fills ``context['hostnames']`` from ``args.hostname`` or, when
        that is not given, from ``util.get_my_hostnames()``.
        """
        config_path = self.args.config
        LOG.info('read configuration file %s', config_path)
        parsed_config = read_configuration_file(config_path)

        # Drop the previous content first (only after the new file parsed
        # successfully): merging into the old configuration would keep
        # entries that were removed from the file alive across a reload.
        self.context_config.clear()
        _deep_update_dict(self.context_config, parsed_config)

        self.context['hostnames'] = [self.args.hostname] if self.args.hostname else util.get_my_hostnames()
        self.context['config'] = self.context_config
        LOG.debug('context: %s', self.context)

    def _load_components(self):
        """Instantiate every enabled component listed in the configuration.

        Each entry names a class exported by the ``agent`` package. A
        component that fails to load is logged and skipped.
        """
        LOG.info('Load components')
        # NOTE: import level -1 (implicit relative import) is Python 2 only.
        agent_lib = __import__('agent', globals(), locals(), [], -1)
        for name, info in self.context_config['components'].items():
            if info.get('disabled'):
                LOG.info('component %s is disabled in config', name)
                continue
            try:
                LOG.info('loading %s', name)
                class_name = info['class']
                # Explicit check instead of ``assert`` so it is not
                # stripped when running with -O.
                if class_name not in agent_lib.__dict__:
                    raise LookupError("class %s not found" % class_name)
                instance = agent_lib.__dict__[class_name](self.context, config=info['config'])
                self.context['components'][name] = instance
            except BaseException:
                LOG.exception('error loading component %s', name)

    def _load_modules(self):
        """Import and instantiate every enabled module, in name order.

        For a configured module ``name`` the class ``AgentModule`` from
        ``modules.<name>`` is instantiated; it must be a ``PlainModule``.
        A module that fails to load is logged and skipped.
        """
        LOG.info('Load modules')
        for name, config in sorted(self.context_config['modules'].items(), key=lambda item: item[0]):
            if config.get('disabled'):
                LOG.info('module %s is disabled in config', name)
                continue
            try:
                LOG.info('loading %s', name)
                # NOTE: import level -1 (implicit relative import) is Python 2 only.
                lib_ = __import__('modules.%s' % name, globals(), locals(), ['AgentModule'], -1)
                instance = lib_.AgentModule(arch=sys.platform, config=config)
                # Explicit check instead of ``assert`` so a wrong module
                # type is still rejected when running with -O.
                if not isinstance(instance, PlainModule):
                    raise TypeError('AgentModule of %s is not a PlainModule' % name)
                self.context['modules'][name] = instance
            except BaseException:
                LOG.exception('error loading module %s', name)

    def _init_components(self):
        """Call ``init()`` on every loaded ``Component`` instance."""
        LOG.info('Init components')
        for name, component in self.context['components'].items():
            LOG.debug('init %s', name)
            if isinstance(component, Component):
                component.init()

    def _init_modules(self):
        """Call ``init()`` on every loaded module, in name order.

        Modules that expose a ``context`` attribute get the shared context
        assigned before ``init()`` is called.
        """
        LOG.info('Init modules')
        for name, module in sorted(self.context['modules'].items(), key=lambda item: item[0]):
            LOG.debug('init %s', name)
            if hasattr(module, 'context'):
                module.context = self.context
            module.init()

    def load_components_and_modules(self):
        """Load components and modules first, then initialise both."""
        self._load_components()
        self._load_modules()

        self._init_components()
        self._init_modules()

    def save_state(self, signal_number, interrupted_stack_frame):
        """Write the current agent state to a new temporary file.

        Signal handler (installed for SIGUSR2 by ``start``); both
        parameters are the standard signal handler arguments and are not
        used. The file is not deleted when it is closed.

        :return: path of the written file
        """
        dp = '%d.%m.%y %H:%M:%S'
        with NamedTemporaryFile(mode='w+', suffix='-current-state', prefix='-oops-agent', delete=False) as f:
            f.write('All times is UTC\n')
            f.write(datetime.utcnow().strftime(dp) + '\n')
            f.write('Agent version: %s\n' % self.state['version'])
            f.write('Start time: %s\n' % self.state['start_time'])
            f.write('\nContext:\n')
            f.write('=' * 30 + '\n')
            f.write(pformat(self.context, indent=2))
            f.write('\nComponents:\n')
            f.write('=' * 30 + '\n')
            for name, component in self.context['components'].items():
                f.write('= %s\n' % name)
                f.write('=' * 30 + '\n')
                f.write(component.get_state())
                f.write('\n\n')

            f.flush()

            LOG.info("Current state saved by request to '%s'", f.name)
            return f.name

    def reload(self, signal_number, interrupted_stack_frame):
        """Signal handler (SIGUSR1): shut everything down and start it again.

        The configuration file is read again as part of the startup.
        """
        LOG.info('Reloading')
        self._shutdown()
        self._startup()

    def sigterm(self, signal_number, interrupted_stack_frame):
        """Signal handler (SIGTERM): stop the agent."""
        LOG.info('got SIGTERM, stopping')
        self.stop()

    def stop(self):
        """Shut modules and components down once; later calls only log."""
        if self.running:
            self.running = False
            self._shutdown()

        LOG.info('exiting')

    def dump_module_and_exit(self, module_name):
        """Print the current value of one module as indented JSON.

        Only modules are loaded and initialised; components are not.

        :param module_name: name of the module as used in the configuration
        :return: 0 on success, 1 if the module is disabled or was not loaded
        """
        self.prepare_config()
        self._load_modules()
        self._init_modules()

        module = self.context['modules'].get(module_name)
        if module is None:
            sys.stderr.write("ERROR: module is disabled or not found: {}\n".format(module_name))
            return 1

        print(json.dumps(module.get_value(), indent=4))
        return 0


def _configure_logging(logfile, dump_states_mode):
    """Set up root logging for the command line entry point.

    Without ``logfile`` the default ``basicConfig`` handler is used
    (CRITICAL only in dump-states mode, DEBUG otherwise). With ``logfile``
    two rotating files are attached to the root logger: the main log and a
    ``-errors`` sibling that receives ERROR and above.
    """
    if logfile is None:
        logging.basicConfig(level='CRITICAL' if dump_states_mode else 'DEBUG')
        return

    log_path = logfile
    if os.path.isdir(log_path):
        log_path = os.path.join(log_path, 'oops-agent.log')

    root_logger = logging.getLogger()
    targets = (
        (log_path, logging.DEBUG),
        ('%s-errors%s' % os.path.splitext(log_path), logging.ERROR),
    )
    for filename, level in targets:
        file_handler = logging.handlers.RotatingFileHandler(filename=filename,
                                                            maxBytes=2 * 1024 * 1024,
                                                            backupCount=2)
        file_handler.setLevel(level)
        file_handler.setFormatter(logging.Formatter('%(asctime)s [%(name)-8s] [%(levelname)s]: %(message)s'))
        root_logger.addHandler(file_handler)

    root_logger.setLevel(logging.DEBUG)


def main():
    """Command line entry point.

    :return: process exit status - 0 on success or keyboard interrupt,
        -1 if the configuration file does not exist, -2 if the agent could
        not be started, -3 if the main loop failed; in ``--dump-module``
        mode the status returned by ``Agent.dump_module_and_exit``
    """
    parser = argparse.ArgumentParser(prog='OopsAgent')

    parser.add_argument('-d', dest='daemonize', action='store_true', help='daemon mode')
    parser.add_argument('config', help='configuration file')

    parser.add_argument('--hostname', help='my hostname')
    parser.add_argument('-f', '--force', action='store_true')

    parser.add_argument('--pidfile', help='path to pid file')
    parser.add_argument('--logfile', help='path to log file or directory where logs will be stored')

    parser.add_argument('--version', action='version', version='%(prog)s ' + AGENT_VERSION)

    parser.add_argument('--dump-states', action='store_true',
                        help='dump state of the running agent modules and exit')

    parser.add_argument("--dump-module", help="dump specified modules and exit")

    args = parser.parse_args()

    _configure_logging(args.logfile, args.dump_states)

    if not os.path.isfile(args.config):
        LOG.error('Configuration file %s not found', args.config)
        return -1

    # One-shot mode: query the already running agent and leave.
    if args.dump_states:
        dump_states(args.config)
        return 0

    LOG.info('Start agent')

    agent = Agent(args)

    # One-shot mode: evaluate a single module in this process and leave.
    if args.dump_module is not None:
        return agent.dump_module_and_exit(args.dump_module)

    try:
        agent.start()
    except KeyboardInterrupt:
        return 0
    except:
        LOG.exception('cannot start agent')
        return -2

    # Idle in the main thread until something clears ``agent.running``.
    try:
        LOG.info('sleep circle')
        while agent.running:
            agent.check_ppid_changes()
            time.sleep(0.1)
    except KeyboardInterrupt:
        return 0
    except:
        LOG.exception('agent loop failed')
        return -3
    finally:
        agent.stop()

    return 0


# Run main() only when executed as a script and use its return value as the
# process exit status.
if __name__ == '__main__':
    sys.exit(main())
