import argparse
import gevent.monkey
import json
import logging
import logging.handlers
import socket
import time

import utils

from shardtool import ShardTool


def main(options):
    """Construct a ShardTool from parsed command-line options and run it.

    The JSON file named by ``options.replica_map`` is loaded and passed to
    ``ShardTool``. The positional arguments, in order, are:
    - the loaded replica map
    - the shard builder executable
    - ``'<shard_root>/<port>'``
    - the shard root
    - ``'<fqdn>:<port>'``
    - the max download speed
    - the copier options

    The tool is then started with ``run(options.port)``. A KeyboardInterrupt
    (Ctrl-C) ends the run quietly instead of printing a traceback.

    :param options: namespace produced by ``parse_args()``.
    """
    try:
        # Close the replica map file deterministically rather than relying on
        # garbage collection to release the handle.
        with open(options.replica_map) as replica_map_file:
            replica_map = json.load(replica_map_file)

        ShardTool(
            replica_map,
            options.shard_builder,
            '{}/{}'.format(options.shard_root, options.port),
            options.shard_root,
            '{}:{}'.format(socket.getfqdn(), options.port),
            options.max_dl_speed,
            options.copier_opts,
        ).run(options.port)
    except KeyboardInterrupt:
        pass


def parse_args():
    """Parse ``sys.argv`` into the options namespace consumed by ``main``.

    Required: ``--shard-builder`` and ``--replica-map``.

    Optional:
    - ``--port`` (int, None when omitted)
    - ``--shard-root`` (defaults to ``'.'``)
    - ``--debug`` (flag, False by default)
    - ``--max-dl-speed`` and ``--copier-opts`` (strings, None when omitted)
    """
    cli = argparse.ArgumentParser()

    # Mandatory inputs.
    cli.add_argument(
        '--shard-builder',
        required=True,
        help='Shard builder executable',
    )
    cli.add_argument(
        '--replica-map',
        required=True,
        help='Path to replicamap.json',
    )

    # Server / storage location.
    cli.add_argument(
        '--port',
        type=int,
        help='Server port',
    )
    cli.add_argument(
        '--shard-root',
        required=False,
        default='.',
        help='Root directory for shards (pwd by default)',
    )

    # Diagnostics.
    cli.add_argument(
        '--debug',
        action='store_true',
        default=False,
        help='Debug mode (affects logging)',
    )

    # Pass-through settings for the downloader.
    cli.add_argument(
        '--max-dl-speed',
        type=str,
        required=False,
        help='Max download speed for skynet (see sky get -H)',
    )
    cli.add_argument(
        '--copier-opts',
        type=str,
        required=False,
        help='Extra options for sky get (sky get --opts ...)',
    )

    parsed = cli.parse_args()
    return parsed


def setup_logging(debug):
    """Configure the root logger for this process.

    Ensures ``./logs`` exists, then attaches to the root logger (set to
    DEBUG) the following handlers, each using ``SkynetishFormatter``:
    - a rotating ``./logs/debug.log`` handler at DEBUG level
    - a rotating ``./logs/info.log`` handler at INFO level
    - when *debug* is true, a ``StreamHandler`` at DEBUG level

    Each rotating file grows to 1024**3 bytes and keeps 10 backups. The
    ``requests`` and ``httpserver`` loggers are raised to CRITICAL to mute
    their output.

    :param debug: also echo DEBUG-and-above records to a stream handler.
    """
    utils.ensure_dir('./logs')

    root = logging.getLogger()
    root.setLevel(logging.DEBUG)

    file_targets = (
        ('./logs/debug.log', logging.DEBUG),
        ('./logs/info.log', logging.INFO),
    )
    for log_path, min_level in file_targets:
        rotating = logging.handlers.RotatingFileHandler(log_path, maxBytes=1024**3, backupCount=10)
        rotating.setLevel(min_level)
        rotating.setFormatter(SkynetishFormatter())
        root.addHandler(rotating)

    if debug:
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(SkynetishFormatter())
        root.addHandler(console)

    # Third-party / library loggers that are too chatty below CRITICAL.
    for noisy_name in ('requests', 'httpserver'):
        logging.getLogger(noisy_name).setLevel(logging.CRITICAL)


class SkynetishFormatter(logging.Formatter):
    """Formatter producing skynet-style log lines.

    A record renders as::

        YYYY-MM-DD HH:MM:SS.mmm [LEVELNAME] [logger.name]  message

    Levels above 5 show their registered name (``[INFO]``, ``[WARNING]``).
    Levels 0..5 render as ``( N)`` and negative levels as ``(-N)``.
    Multi-line messages include any traceback and stack info. Every line after
    the first is indented to the message column and prefixed with ``': '``,
    which makes continuation lines easy to recognise.
    """

    # strftime pattern used by formatTime() when no datefmt is supplied.
    _DEFAULT_TIME_FORMAT = '%Y-%m-%d %H:%M:%S'

    def __init__(self):
        super(SkynetishFormatter, self).__init__()

    def formatTime(self, record, datefmt=None):
        """Return the creation time of *record* as a string.

        The signature matches ``logging.Formatter.formatTime``, so the method
        can be called positionally like the base class. The previous
        ``**kwargs``-only signature raised TypeError for
        ``formatTime(record, datefmt)``.

        Like the base class, an explicit *datefmt* is applied as-is. Without
        one, the default ``'%Y-%m-%d %H:%M:%S'`` pattern is used and
        milliseconds are appended as ``.mmm``.
        """
        created = self.converter(record.created)
        if datefmt:
            return time.strftime(datefmt, created)
        stamp = time.strftime(self._DEFAULT_TIME_FORMAT, created)
        return '%s.%03d' % (stamp, record.msecs)

    def format(self, record):
        """Render *record* in the layout described in the class docstring."""
        levelno = record.levelno
        if levelno > 5:
            level = '[%-4s]' % logging.getLevelName(levelno)
        else:
            # Small custom levels: right-aligned number, sign-aware.
            level = '(%s)' % (str(levelno) if levelno < 0 else ' %d' % levelno)

        message = record.getMessage()

        # Render the traceback once and cache it on the record, as the base
        # Formatter does. A record passed to several handlers is then not
        # re-formatted by each one.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            message += '\n' + record.exc_text

        # stack_info exists only on Python 3 records (logging calls made with
        # stack_info=True). Previously it was silently dropped.
        stack_info = getattr(record, 'stack_info', None)
        if stack_info:
            message += '\n' + self.formatStack(stack_info)

        header = '{0} {1} [{2}]  '.format(self.formatTime(record), level, record.name)

        if '\n' not in message:
            # Plain concatenation: avoids str.format's implicit ASCII encode
            # of unicode messages on Python 2; identical on Python 3.
            return header + message

        # Multi-line: align continuation lines under the message column,
        # replacing the header's two trailing spaces with ': '.
        message_lines = message.splitlines()
        continuation = ' ' * (len(header) - 2) + ': '
        rendered = [header + message_lines[0]]
        rendered.extend(continuation + tail for tail in message_lines[1:])
        return '\n'.join(rendered)


if __name__ == '__main__':
    # Script entry point. Statement order matters: gevent's monkey patching
    # runs first, before argument parsing, logging setup and main(). DNS
    # resolution is excluded from patching (dns=False).
    # NOTE(review): gevent recommends patching as early as possible, ideally
    # before other modules are imported. Here it runs after the module-level
    # imports (including shardtool); confirm nothing imported above captured
    # unpatched objects.
    gevent.monkey.patch_all(dns=False)
    args = parse_args()
    # Handlers are attached before main() so that anything logged from
    # main() onward goes through the ./logs handlers set up here.
    setup_logging(args.debug)
    main(args)
