from __future__ import print_function, absolute_import

from ..rbtorrent import monkey_patch as skbn_monkey_patch

# Apply the rbtorrent monkey patches before any of the imports below run.
# NOTE(review): presumably this must precede the gevent/socket-using imports so
# they pick up the patched modules -- confirm before reordering imports.
skbn_monkey_patch()

import argparse
import errno
import logging
import os
import random
import struct
import sys
import time
import yaml

import gevent
import py

from setproctitle import setproctitle

from api.logger import SkynetLoggingHandler

from library.sky.hostresolver import Resolver as BCalcResolver

from ..kernel_util.sys.gettime import monoTime

from ..rbtorrent.utils import Path, SlottedDict
from ..rbtorrent.logger import initialize as initialize_log, JobUidAdapter
from ..rbtorrent.component import Component
from ..rbtorrent.db import Database
from ..rbtorrent.greenish.deblock import Deblock
from ..rpc.server import RPC, Server as RPCServer

from .rpc import SkyboneMDSRPC, SkyboneMDSAdminRPC

from .resource_manager import ResourceManager


# Program name; used as the prefix of the process title set by update_proc_title().
PROGNAME = 'skybone-mds'


class SkyboneMDS(Component):
    """Skybone MDS daemon component.

    ``start()`` performs, in order:
      * binds a localhost TCP "singleton" listener on ``cfg.rpc_port`` (plus
        the port offset) and exits with status 1 if that address is already
        in use, so a second instance on the same port refuses to start;
      * serves the RPC API on the unix socket ``<workdir>/rpc.sock``;
      * detects local addresses, resolves tracker hosts, opens/migrates the
        database and creates the ``ResourceManager`` child component.

    ``startup_stage`` holds ``(timestamp, stage_name)`` of the current startup
    step and ends at ``'done'``.
    """

    def __init__(self, app_path, cfg, log, workdir):
        """
        :param app_path: application root directory (a ``Path``); database
            migrations are read from ``<app_path>/share/db/mds``.
        :param cfg: config with attribute access (``rpc_port``, ``data_port``,
            ``announce_port``, ``db.mmap``, ``coord_hosts``, ...).
        :param log: logger used by this component.
        :param workdir: working directory path; created if it does not exist.
        """
        super(SkyboneMDS, self).__init__(logname='')
        self.log = log

        # Optional shift added to the singleton, data and announce ports.
        try:
            self.port_offset = int(os.environ.get('SKYNET_PORTOFFSET', 0))
        except (TypeError, ValueError):
            # Malformed value in the environment: run without an offset.
            self.port_offset = 0

        self.singletone_port = cfg.rpc_port + self.port_offset
        self.rpc_sock = Path(workdir).join('rpc.sock')

        self.path = app_path
        self.cfg = cfg

        self.workdir = Path(workdir).ensure(dir=1)
        self.active = gevent.event.Event()
        self.stopping = False

        # Filled in by init_network() / init_tracker_hosts().
        self.ips = None
        self.trackers = None

        self.startup_stage = (time.time(), 'init')
        self.ts_started = monoTime()

    def __repr__(self):
        return '<SkyboneMDS>'

    def _cfg_kmg_convert(self, value):
        """Convert a size string with a ``K``/``M``/``G`` suffix to bytes.

        Suffixes are binary (``'1K'`` -> 1024). A value containing ``'.'`` is
        parsed as float, otherwise as int. Non-string values, and strings
        without one of these suffixes, are returned unchanged.
        """
        if isinstance(value, basestring):
            meth = (float if '.' in value else int)
            if value.endswith('K'):
                value = 1024 * meth(value.rstrip('K'))
            elif value.endswith('M'):
                value = 1024 * 1024 * meth(value.rstrip('M'))
            elif value.endswith('G'):
                value = 1024 * 1024 * 1024 * meth(value.rstrip('G'))

        return value

    def init_network(self):
        """Detect this host's name and backbone addresses.

        Sets ``self.hostname`` and ``self.ips``. ``self.ips`` is a dict with
        keys ``'bb'`` (IPv4), ``'bb6'`` (IPv6), ``'fb'`` and ``'fb6'``. Only
        ``bb``/``bb6`` are filled here, from resolving the hostname, and
        loopback addresses are ignored.

        If the hostname does not resolve (NONAME/NODATA), the addresses stay
        ``None``. Any other resolver error propagates.
        """
        from ..rbtorrent.utils import getaddrinfo_g as getaddrinfo
        from _socket import getaddrinfo as _getaddrinfo

        self.ips = {'bb': None, 'bb6': None, 'fb': None, 'fb6': None}

        hostname = gevent.socket.gethostname()

        if '.' not in hostname:
            # Prefer the FQDN for a short name, unless the resolver hands back
            # 'any.yandex.ru' (presumably a catch-all answer -- not a real name).
            fqhostname = gevent.socket.getfqdn(hostname)
            if fqhostname != 'any.yandex.ru':
                hostname = fqhostname

        self.log.debug('Detecting our ips (hostname: %s)', hostname)

        try:
            self.log.debug('  resolve %s...', hostname)
            # NOTE(review): getaddrinfo_g is handed the raw C resolver; presumably it
            # runs it without blocking the gevent loop -- confirm in rbtorrent.utils.
            ips = getaddrinfo(hostname, 0, 0, gevent.socket.SOCK_STREAM, getaddrinfo=_getaddrinfo)
            for family, _, _, _, ipinfo in ips:
                if family == gevent.socket.AF_INET:
                    if ipinfo[0] != '127.0.0.1':
                        self.ips['bb'] = ipinfo[0]
                        self.log.debug('    detected bb: %s', ipinfo[0])
                elif family == gevent.socket.AF_INET6:
                    if ipinfo[0] != '::1':
                        self.ips['bb6'] = ipinfo[0]
                        self.log.debug('    detected bb6: %s', ipinfo[0])
        except gevent.socket.gaierror as ex:
            if ex.errno in (
                getattr(gevent.socket, 'EAI_NONAME', None),  # NOT FOUND
                getattr(gevent.socket, 'EAI_NODATA', None),  # NO ADDRESS ASSOCIATED WITH HOSTNAME (LINUX)
            ):
                self.log.debug('    not found (errno: %d)', ex.errno)
            else:
                # Any other DNS error is fatal: let it propagate so the daemon
                # dies instead of starting with incomplete address information.
                raise

        self.hostname = hostname

    def init_tracker_hosts(self):
        """Resolve coordinator (tracker) hosts into ``self.trackers``.

        ``cfg.coord_hosts`` is expanded to host names with the blinov-calc
        ``Resolver``. Each host is then resolved to ``(ip, port)`` pairs on
        ``cfg.coord_hosts_port``. IPv4 results are stored IPv4-mapped
        (``'::ffff:a.b.c.d'``) so every entry has the same format.

        Result shape::

            self.trackers['coord'][hostname] = [(ip, port), ...]

        Must run after init_network(). Address families follow what it
        detected (IPv6 first); if nothing was detected, both families are
        tried. A family that fails with NONAME/NODATA is logged and skipped.
        Other resolver errors propagate.
        """
        from ..rbtorrent.utils import getaddrinfo_g as getaddrinfo
        from _socket import getaddrinfo as _getaddrinfo

        resolver = BCalcResolver()

        self.trackers = {}

        self.log.debug('Initializing tracker list')

        # Only query families we have a local address for (IPv6 first); fall
        # back to both when detection found nothing. Independent of the loop.
        families = []
        if self.ips['bb6']:
            families.append(gevent.socket.AF_INET6)
        if self.ips['bb']:
            families.append(gevent.socket.AF_INET)
        if not families:
            families = [gevent.socket.AF_INET6, gevent.socket.AF_INET]

        for expr, port in ((self.cfg.coord_hosts, self.cfg.coord_hosts_port), ):
            self.log.debug('  Resolving blinov calc: %s...', expr)
            self.trackers['coord'] = {}
            hostnames = resolver.resolveHosts(expr)
            self.log.debug('    resolved: %s', hostnames)
            self.log.debug('  Resolving ips...')

            for hostname in hostnames:
                ips = []

                for family in families:
                    try:
                        for ipinfo in getaddrinfo(
                            hostname, port,
                            family, gevent.socket.SOCK_STREAM, gevent.socket.IPPROTO_IP,
                            getaddrinfo=_getaddrinfo
                        ):
                            ip = ipinfo[4][0]
                            if ipinfo[0] == gevent.socket.AF_INET:
                                ip = '::ffff:%s' % (ip, )

                            ipport = (ip, ipinfo[4][1])
                            if ipport not in ips:
                                ips.append(ipport)
                    except gevent.socket.gaierror as ex:
                        if ex.errno in (
                            getattr(gevent.socket, 'EAI_NONAME', None),
                            getattr(gevent.socket, 'EAI_NODATA', None),
                        ):
                            self.log.warning('    Unable to resolve %s:%d: %s, skipping', hostname, port, ex)
                        else:
                            raise

                self.trackers['coord'][hostname] = ips

            self.log.debug('    resolved: %s', self.trackers['coord'])

    def init_db(self):
        """Open, migrate and maintain ``<workdir>/main.db``; load or create the uid.

        The Database object is wrapped in a Deblock proxy (``self.db``).
        Migration files are read from ``<app_path>/share/db/mds``; only
        ``.sql``/``.py`` files are considered. Names starting with ``fw_`` go
        to the first mapping passed to ``migrate`` and names starting with
        ``bw_`` to the second. The version is the integer after the first
        ``_``.

        One maintenance pass runs now, and two greenlets keep running
        maintenance until ``self.stopping`` is set.

        ``self.uid`` is read from the ``dbmaintain`` table, where it is stored
        as ``'1:<uid>'``. If it is missing or in another format, a new random
        64-bit hex uid is generated and stored.
        """
        db_path = self.workdir.join('main.db')
        self.log.debug('Opening database at %s', db_path)

        db_temp_dir = self.workdir.join('temp').ensure(dir=1)

        dbobj = Database(
            db_path.strpath,
            mmap=self._cfg_kmg_convert(self.cfg.db.mmap),
            temp=db_temp_dir.strpath,
            parent=self
        )

        self.db_deblock = Deblock(keepalive=None, logger=dbobj.log.getChild('deblock'))
        self.db = self.db_deblock.make_proxy(dbobj, put_deblock='deblock')

        self.startup_stage = (time.time(), 'init_db')
        self.db.open(check=self.db.CHECK_IF_DIRTY, force=True)
        self.db.set_debug(sql=True, transactions=True)

        self.startup_stage = (time.time(), 'migrate_db')
        migrations_path = self.path.join('share', 'db', 'mds')

        fw_migrations = {}
        bg_migrations = {}

        if migrations_path.check(exists=1, dir=1):
            for migration_file in migrations_path.listdir():
                if not (migration_file.basename.endswith('.sql') or migration_file.basename.endswith('.py')):
                    continue
                if migration_file.basename.startswith('fw_'):
                    target = fw_migrations
                elif migration_file.basename.startswith('bw_'):
                    target = bg_migrations
                else:
                    continue

                # '<prefix>_<version>_<rest>' -> int(version)
                version = int(migration_file.basename.split('_', 2)[1])
                target[version] = migration_file

        self.db.migrate(fw_migrations, bg_migrations)

        self.startup_stage = (time.time(), 'maintain_db')

        def _maintain(**kwargs):
            # Maintenance failures are logged, never fatal.
            try:
                with self.db.deblock.lock('dbmaintain'):
                    self.db.maintain(**kwargs)
            except Exception as ex:
                self.log.critical('Maintainance error: %s', str(ex))

        _maintain(vacuum=86400 * 7, analyze=86400, grow=True)

        def _mnt(t, **kwargs):
            # Sleep first, then re-check the stop flag before each pass.
            while 1:
                gevent.sleep(t)
                if self.stopping:
                    break
                _maintain(**kwargs)

        # Wake hourly for vacuum/analyze and every 10s for grow. The vacuum
        # (week) and analyze (day) values are passed to Database.maintain --
        # presumably as "run if older than" thresholds; confirm there.
        gevent.spawn(_mnt, 3600, vacuum=86400 * 7, analyze=86400)
        gevent.spawn(_mnt, 10, grow=True)

        uid_info = self.db.query_one_col('SELECT value_text FROM dbmaintain WHERE key = ?', ['uid'])
        if uid_info:
            uid_version = int(uid_info[0])
            if uid_version == 1 and uid_info[1] == ':':
                self.uid = uid_info[2:]
            else:
                self.uid = None
        else:
            self.uid = None

        if not self.uid:
            # randint's upper bound is inclusive: 2 ** 64 itself would not fit '!Q'.
            self.uid = struct.pack('!Q', random.randint(0, 2 ** 64 - 1)).encode('hex')
            self.db.query(
                'REPLACE INTO dbmaintain VALUES (?, ?, null)',
                ['uid', '1:' + self.uid]
            )

    def init_workdir(self):
        """Placeholder hook; nothing to prepare in the working directory yet."""
        return

    def start(self):
        """Bring the daemon up.

        Steps, in order:
          1. Bind the localhost singleton listener; exit(1) if the port is in use.
          2. Remove any stale RPC socket file, then mount the RPC methods and
             start the RPC server.
          3. Run init_network, init_tracker_hosts, init_db and init_workdir.
          4. Create the ResourceManager child.
          5. Call ``Component.start()`` and set ``self.active``.
        """
        self.log.debug('Starting RPC server serving %r on %s', self, self.rpc_sock)

        try:
            self.singletone_sock = gevent.socket.socket(gevent.socket.AF_INET, gevent.socket.SOCK_STREAM)
            self.singletone_sock.setsockopt(gevent.socket.SOL_SOCKET, gevent.socket.SO_REUSEADDR, 1)
            self.singletone_sock.bind(('127.0.0.1', self.singletone_port))
            self.singletone_sock.listen(1)
        except gevent.socket.error as ex:
            if ex.errno == errno.EADDRINUSE:
                error_msg = 'ERROR: address 127.0.0.1:%d already in use! Quitting...' % (self.singletone_port, )
                print(error_msg, file=sys.stderr)
                self.log.critical(error_msg)
                sys.exit(1)
            raise

        # A socket file left over from a previous run would make the bind fail.
        try:
            self.rpc_sock.remove()
        except py.error.ENOENT:
            pass

        self.skbn_mds_rpc = SkyboneMDSRPC(self)
        self.skbn_mds_admin_rpc = SkyboneMDSAdminRPC(self)

        self.rpc = RPC(self.log.getChild('rpc'))
        self.rpc_server = RPCServer(self.log, backlog=10, max_conns=1000, unix=self.rpc_sock.strpath)
        self.rpc_server.register_connection_handler(self.rpc.get_connection_handler())

        self.rpc.mount(self.skbn_mds_rpc.ping)
        self.rpc.mount(self.skbn_mds_rpc.add_resource)
        self.rpc.mount(self.skbn_mds_rpc.remove_resource)
        self.rpc.mount(self.skbn_mds_admin_rpc.status)
        self.rpc.mount(self.skbn_mds_admin_rpc.query)
        self.rpc.mount(self.skbn_mds_admin_rpc.dbbackup)
        self.rpc.mount(self.skbn_mds_admin_rpc.evaluate, typ='full')
        self.rpc.mount(self.skbn_mds_admin_rpc.resource_list)
        self.rpc.mount(self.skbn_mds_admin_rpc.resource_remove)

        self.rpc_server.start()

        self.log.info('%r started', self)

        self.init_network()
        self.init_tracker_hosts()
        self.init_db()
        self.init_workdir()

        self.startup_stage = (time.time(), 'init components')

        self.resource_mngr = ResourceManager(
            self.uid, self.hostname,
            self.workdir, self.db,
            trackers=self.trackers,
            data_port=self.cfg.data_port + self.port_offset,
            announce_port=self.cfg.announce_port + self.port_offset,
            ips=self.ips,
            dfs_mode=self.cfg.dfs_mode,
            dfs_link_pattern=self.cfg.dfs_link_pattern,
            parent=self
        )

        super(SkyboneMDS, self).start()

        self.startup_stage = (time.time(), 'done')
        self.active.set()

        self.log.debug('Allowed incoming connections')

    def wait(self):
        """Block forever; on KeyboardInterrupt clear ``active``, stop and join."""
        try:
            while True:
                gevent.sleep(120)
        except KeyboardInterrupt:
            self.log.info('Caught KeyboardInterrupt (SIGINT), shutting down...')
            self.active.clear()
            self.stop()
            self.join()

    def stop(self):
        """Stop the component.

        Sets ``stopping``, which ends the DB maintenance loops at their next
        wake-up. Then closes the database and its Deblock, and stops the base
        component.
        """
        self.stopping = True

        self.db.close()
        self.db.deblock.stop()

        return super(SkyboneMDS, self).stop()


def load_config(filename, fmt):
    """Load the YAML config ``filename`` and return it as a ``SlottedDict``.

    :param filename: path to the YAML file (parsed with ``yaml.safe_load``).
    :param fmt: ``'ctl'`` -- the file is the service config itself;
        ``'skycore'`` -- the service config is nested under
        ``subsections.skynet.subsections.services.subsections.skybone-mds.config.config``
        and is extracted from there.
    """
    # Close the file deterministically instead of leaking the handle.
    with open(filename, 'rb') as fp:
        cfg = yaml.safe_load(fp)

    if fmt == 'skycore':
        for section in ('skynet', 'services', 'skybone-mds'):
            cfg = cfg['subsections'][section]
        cfg = cfg['config']['config']

    return SlottedDict(cfg)


def determine_app_path():
    """Return the nearest ancestor directory of this file that contains ``ctl.py``.

    Ancestors are searched from the deepest one upwards.

    :raises Exception: when no ancestor contains ``ctl.py``.
    """
    here = Path(__file__)
    app_root = next(
        (candidate for candidate in reversed(here.parts())
         if candidate.join('ctl.py').check(exists=1)),
        None,
    )
    if app_root is None:
        raise Exception('Unable to find app path')
    return app_root


def update_proc_title(cfg):
    """Set the process title to ``'<PROGNAME> [<rpc_port>/<data_port>]'``."""
    title = '%s [%d/%d]' % (PROGNAME, cfg.rpc_port, cfg.data_port)
    setproctitle(title)


def _client(sock):
    """Return an ``RPCClient`` for the RPC socket path ``sock``."""
    from ..rpc.client import RPCClient
    return RPCClient(sock, None)


def _probe_socket(name, addr, port, family=None):
    """Check that ``addr:port`` accepts a connection and drops junk input.

    Sends 68 bytes of garbage (NOTE(review): presumably the size of a
    BitTorrent handshake, 49 + 19-byte protocol string -- confirm). The check
    passes when the peer then closes the connection: either an empty read or
    ECONNRESET. Progress is printed to stdout.

    Returns True when the socket looks healthy, False otherwise. The socket
    is always closed.
    """
    import socket

    if family is None:
        family = socket.AF_INET

    print('check %s socket' % (name, ))

    sock = None
    try:
        try:
            sock = socket.socket(family, socket.SOCK_STREAM)
            sock.settimeout(10)
            sock.connect((addr, port))
        except socket.error as ex:
            print('unable to connect to %s socket: %s, running=False' % (name, ex))
            return False
        except Exception as ex:
            print('unknown error while connecting to %s socket: %s, running=False' % (name, ex))
            return False

        try:
            sock.send('a' * (49 + 19))
            data = sock.recv(1)
            if data != '':
                print(
                    '%s socket didnt closed connection (result: %r), running=False' % (
                        name, data
                    )
                )
                return False
        except socket.error as ex:
            # A reset is an acceptable way for the peer to reject the junk.
            if ex.errno != errno.ECONNRESET:
                print('%s socket not good error: %s, running=False' % (name, ex))
                return False
        except Exception as ex:
            print('%s socket unknown error: %s, running=False' % (name, ex))
            return False

        print('%s socket ok' % (name, ))
        return True
    finally:
        # Previously the socket leaked on every failure path.
        if sock is not None:
            sock.close()


def _skycore_ping(sock_dir):
    """Health-check the daemon whose RPC socket is ``<sock_dir>/rpc.sock``.

    Retries ``ping`` every 5 seconds for up to 120 seconds. Returns a process
    exit code:

    * 0 -- startup is still in progress and has not been in its current
      stage for a full day, or the daemon is active and its data socket
      passes the probe;
    * 1 -- otherwise.
    """
    sock = sock_dir + '/rpc.sock'

    deadline = time.time() + 120

    while time.time() < deadline:
        try:
            result = _client(sock).call('ping', extended=True).wait()
            print('ping result: %r' % (result, ))
            break
        except Exception as ex:
            print('ping error: %s' % (ex, ))
            time.sleep(5)
    else:
        print('ping timeouts exhaused')
        return 1

    if result['stage'][1] != 'done':
        # Still starting up: tolerate a slow stage for up to one day.
        in_stage = time.time() - result['stage'][0]
        return 0 if in_stage < 86400 else 1

    if not result.get('active', False):
        return 1

    # NOTE(review): fixed port 7000 -- ignores cfg.data_port and SKYNET_PORTOFFSET.
    if not _probe_socket('bb', '127.0.0.1', 7000):
        return 1

    print('(finish) running: True')

    return 0


def _install_green_subprocess():
    """Point already-imported kernel_util.sys.user submodules at the green subprocess.

    ``sys.modules['subprocess']`` is swapped to the gevent-friendly module
    while patching, then set to ``subprocess_gevent._subprocess``.
    """
    from ..rbtorrent import subprocess_gevent
    sys.modules['subprocess'] = subprocess_gevent
    # Iterate over a snapshot: sys.modules must not change size mid-iteration.
    for name, mod in list(sys.modules.items()):
        if name.startswith('skybone.kernel_util.sys.user._'):
            if hasattr(mod, 'subprocess'):
                mod.subprocess = subprocess_gevent

    sys.modules['subprocess'] = subprocess_gevent._subprocess


def main():
    """Command-line entry point.

    With ``--skycore-ping <dir>`` it only health-checks a running daemon and
    returns its exit code. Otherwise it loads the config, sets up logging and
    runs ``SkyboneMDS`` until interrupted. Any unhandled error makes the
    process exit immediately with status 1.
    """
    app_path = determine_app_path()

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', required=True, help='Config file to use')
    parser.add_argument('-cfmt', choices=['ctl', 'skycore'], default='ctl')

    parser.add_argument(
        '-d', '--workdir', required=False, default=os.path.abspath('rbtorrent'),
        help='Working directory'
    )

    parser.add_argument(
        '--skycore-ping', default=None
    )

    args = parser.parse_args()

    if args.skycore_ping:
        return _skycore_ping(args.skycore_ping)

    # Force kernel.util.sys.user to use green subprocess module
    _install_green_subprocess()

    cfg = load_config(args.config, args.cfmt)

    update_proc_title(cfg)
    initialize_log(SkynetLoggingHandler(app='skybone-mds', filename='skybone-mds.log'))

    root_log = logging.getLogger()
    main_log = JobUidAdapter(root_log.getChild('main'), {})

    main_log.info('Initializing')

    try:
        skbn_mds = SkyboneMDS(app_path, cfg, log=main_log, workdir=args.workdir)
        skbn_mds.start()
        skbn_mds.wait()
    except KeyboardInterrupt:
        main_log.debug('interrupted!')
        print('interrupted!', file=sys.stderr)
        os._exit(1)
    except Exception:
        import traceback
        # Fatal: log at critical so it is not filtered out like a debug record.
        main_log.critical('Unhandled daemon exception: %s--\nQuit immediately', traceback.format_exc())
        os._exit(1)

    return 0
