import os
import os.path
from os.path import join as pj
import time
import logging

from sandbox.sandboxsdk.process import run_process
from sandbox.sandboxsdk.paths import make_folder
from sandbox.sandboxsdk import network
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

# Resource budget for one MR host process; used to decide how many hosts to start.
# 4 * 2**30 == 4294967296, the same value the expression "2 * 2 << 30" yields
# (multiplication binds tighter than the shift).
# NOTE(review): compared against client_info['physmem'], presumably in bytes -- confirm.
MEM_PER_PROCESS = 4 * 1024 ** 3
CPU_PER_PROCESS = 2
# Lower and upper bounds applied to the computed number of host processes.
MIN_HOSTS = 4
MAX_HOSTS = 32  # mostly for debugging


class MapreduceRunner(object):
    def __init__(self, mr_binary, root_dir, client_info, log_path=''):
        """Record the paths, ports and sizing for a local MR master plus hosts.

        mr_binary   -- path to the MR executable; stored as an absolute path
        root_dir    -- root working directory; stored as an absolute path,
                       with the temporary directory placed at <root_dir>/tmp
        client_info -- mapping with 'ncpu' and 'physmem' entries, used by
                       calc_usage_parameters() to choose the number of hosts
        log_path    -- master log file; when empty, get_log_path() falls back
                       to server.log inside the master directory
        """
        # Filesystem layout.
        absolute_root = os.path.abspath(root_dir)
        self.mr_binary = os.path.abspath(mr_binary)
        self.root_dir = absolute_root
        self.tmp_dir = pj(absolute_root, 'tmp')

        # Cluster sizing (sets self.nproc).
        self.client_info = client_info
        self.calc_usage_parameters()

        # Every process started by this runner is tracked here for teardown.
        self.subprocesses = []

        # Network endpoints of the master.
        self.server_port = 19013
        self.http_port = 8080
        self.mr_server_string = 'localhost:%s' % str(self.server_port)

        self.log_path = log_path

    def get_tmp_dir(self):
        """Return the temporary directory handed to spawned processes as TMP/TMPDIR."""
        tmp_dir = self.tmp_dir
        return tmp_dir

    def get_log_path(self):
        """Return the master log file path.

        An explicitly configured (non-empty) log_path wins; otherwise the log
        is placed at server.log inside the master directory.
        """
        return self.log_path or pj(self.master_dir(), 'server.log')

    def calc_usage_parameters(self):
        """Compute self.nproc, the number of MR host processes to start.

        The count is limited by whichever of memory (physmem / MEM_PER_PROCESS)
        or CPU (ncpu / CPU_PER_PROCESS) allows fewer processes, then forced
        into the [MIN_HOSTS, MAX_HOSTS] interval.
        """
        info = self.client_info
        cpu_count = int(info['ncpu'])
        memory = int(info['physmem'])

        fits_in_memory = int(float(memory) / MEM_PER_PROCESS)
        fits_on_cpus = int(float(cpu_count) / CPU_PER_PROCESS)

        hosts = min(fits_in_memory, fits_on_cpus)
        # Raise to the minimum first, then cap at the maximum.
        if hosts < MIN_HOSTS:
            hosts = MIN_HOSTS
        if hosts > MAX_HOSTS:
            hosts = MAX_HOSTS
        self.nproc = hosts

    def setup_server(self, start=True):
        """Bring up the local MR cluster; with start=False this does nothing.

        Steps, in order: create the directory layout, write the master's
        config files, launch the master, launch the hosts.
        """
        if not start:
            return
        steps = (
            self.create_directores,
            self.write_config,
            self.run_master,
            self.run_hosts,
        )
        for step in steps:
            step()

    def teardown_server(self):
        for p in self.subprocesses:
            try:
                p.terminate()
            except OSError:  # probably died from some other causes, don't mangle the error message by failing in cleanup here
                pass
        alive = self.subprocesses
        time.sleep(5)
        while alive:
            alive = [p for p in alive if p.poll() is None]
            for p in alive:
                p.kill()

    def create_directores(self):
        """Create the on-disk layout for the master and all host processes.

        The root directory is (re)created with delete_content=True, then the
        tmp and master directories are created, and one directory per host.
        The master directory and every host directory get a symlink to the MR
        binary, named after the binary's basename.
        """
        make_folder(self.root_dir, delete_content=True)
        make_folder(self.tmp_dir)
        make_folder(self.master_dir())
        os.symlink(self.mr_binary, self.master_mr())
        # range() instead of the Python-2-only xrange(); nproc is small
        # (calc_usage_parameters caps it at MAX_HOSTS), so a list costs nothing.
        for n in range(self.nproc):
            make_folder(self.host_dir(n))
            os.symlink(self.mr_binary, self.host_mr(n))

    def get_environ(self):
        """Return extra environment variables; this implementation has none."""
        return dict()

    def get_cmd_environ(self):
        """Return the environment variables to prefix MR client commands with."""
        cmd_environ = dict(MR_RUNTIME='MR')
        return cmd_environ

    def get_cmd_environ_str(self):
        """Render get_cmd_environ() as space-separated KEY=VALUE pairs."""
        pairs = self.get_cmd_environ().items()
        return ' '.join(map('='.join, pairs))

    def write_config(self):
        with open(pj(self.master_dir(), 'server.cfg'), 'w') as f:
            f.write("""
            <Main>

            ChunkSize = 67108864
            MaxKeySize = 4096
            MaxValueSize = 67100656
            MaxMROpSize = 67108864
            ChunkReplicas = 2

            ChunkSizeToTakeEachSync = 268435456
            ChunksToDeleteEachSync = 256
            MaxChunksPerPartJob = 8
            MaxFragmentRequests = 4
            MaxChunkRequests = 4
            SortBufferSize = 268435456
            ParallelJobs = 2
            FileServerMaxFileSize = 4294967296
            FileServerCacheSize = 4294967296
            FileServerMaxResponses = 64
            FileServerMaxResponsesSize = 1073741824
            AnalyzePerformance = false
            AllowStartWithoutMetaInfo = true
            SnapshotTransactionCount = 256
            SnapshotHistorySize = 5
            MasterStateUpdateIntervalInMilliseconds = 15000
            SchedulerStateUpdateIntervalInMilliseconds = 1000

            DebugHttpWithDirectorySizesEnabled = 1

            NoAtime = 0

            </Main>
            """)
        with open(pj(self.master_dir(), 'quota.cfg'), 'w') as f:
            f.write("""
                <main>
                check = enabled
                defaultDiskSizeQuota = 500G
                defaultChunkCountQuotaFactor = 1M
                defaultTransactionsPerSecond = 0
                cachePeriod = 2
                hiddenSpaceQuota = 100G

                defaultFileServerQuota = 16G
                defaultFileServerPartsQuota = 100
                </main>
                """)

    def run_master(self):
        """Launch the MR master process from inside the master directory.

        The process is started without waiting for it, remembered as
        self.master_process and added to self.subprocesses. The caller's
        working directory is restored afterwards, even if the launch fails.
        """
        previous_cwd = os.getcwd()
        try:
            os.chdir(self.master_dir())
            command = [self.master_mr()]
            command.extend(['-runserver', str(self.server_port)])
            command.extend(['-http', str(self.http_port)])
            command.extend(['-log', self.get_log_path()])
            # Point the master's temporary files at the runner's tmp directory.
            master_env = dict(TMP=self.get_tmp_dir(), TMPDIR=self.get_tmp_dir())
            master = run_process(
                command,
                environment=master_env,
                shell=False,
                wait=False,
                check=False,
                log_prefix='mapreduce.master',
            )
            self.master_process = master
            self.subprocesses.append(master)
        finally:
            os.chdir(previous_cwd)

    def master_failure(self):
        for retry in xrange(3):
            retval = self.master_process.poll()
            if retval is not None:
                if (retval >= 0):
                    msg = " returned %d" % retval
                else:
                    msg = " killed by signal %d" % abs(retval)
                msg = "Mapreduce master failed: "+msg
                logging.warn(msg)
                return msg
            else:
                logging.info("Mapreduce master process %d still running" % (self.master_process.pid))
                time.sleep(5*(retry+1))
        return None

    def run_hosts(self):
        """Launch self.nproc MR host processes, one per host directory.

        Each host is started from inside its own directory, pointed at the
        master via mr_server_string, and given its own HTTP port and
        SKLAD_DEBUG_PORT. Started processes are appended to self.subprocesses.
        There is a one-second pause after each launch, and the caller's
        working directory is restored after every host.
        """
        pwd = os.getcwd()
        # range() instead of the Python-2-only xrange(); nproc is small.
        for n in range(self.nproc):
            try:
                os.chdir(self.host_dir(n))
                cmd = [self.host_mr(n), '-runhost', self.mr_server_string, '-http', str(self.host_http_port(n))]
                self.subprocesses.append(
                    run_process(
                        cmd,
                        environment={
                            'SKLAD_DEBUG_PORT': str(self.host_sklad_port(n)),
                            'TMP': self.get_tmp_dir(),
                            'TMPDIR': self.get_tmp_dir(),
                        },
                        shell=False, wait=False, check=False, log_prefix='mapreduce.host'
                    )
                )
                # Stagger host start-up.
                time.sleep(1)
            finally:
                os.chdir(pwd)

    def master_dir(self):
        """Return the master's working directory, <root_dir>/master."""
        return os.path.join(self.root_dir, 'master')

    def host_dir(self, n):
        """Return the working directory of host number n, <root_dir>/hostNNNN."""
        folder_name = 'host%04d' % n  # zero-padded to at least four digits
        return os.path.join(self.root_dir, folder_name)

    def host_mr(self, n):
        """Return the path of the MR binary link inside host n's directory."""
        binary_name = os.path.basename(self.mr_binary)
        return os.path.join(self.host_dir(n), binary_name)

    def master_mr(self):
        """Return the path of the MR binary link inside the master directory."""
        binary_name = os.path.basename(self.mr_binary)
        return os.path.join(self.master_dir(), binary_name)

    def host_http_port(self, n):
        """Pick a free HTTP port for host n.

        Each host owns a block of 100 ports above the master's http_port;
        offsets 1..49 of that block are candidates for the host's HTTP port.
        Raises whatever search_for_port() raises when none of them is free.
        """
        block_start = self.http_port + 100 * n
        # range() instead of the Python-2-only xrange(); 49 candidates at most.
        return self.search_for_port(range(block_start + 1, block_start + 50))

    def host_sklad_port(self, n):
        """Pick a free port for host n's SKLAD_DEBUG_PORT and log the choice.

        Candidates are offsets 51..99 of the host's 100-port block above the
        master's http_port, so they never overlap the host's HTTP port range.
        Raises whatever search_for_port() raises when none of them is free.
        """
        block_start = self.http_port + 100 * n
        # range() instead of the Python-2-only xrange(); 49 candidates at most.
        p = self.search_for_port(range(block_start + 50 + 1, block_start + 100))
        logging.info("Sklad port: %d", p)
        return p

    def get_proxy_string(self):
        """Return the master's 'host:port' address (same value as mr_server_string)."""
        server_address = self.mr_server_string
        return server_address

    def path_mr_client(self):
        """Return the path of the MR binary to use as the client (the original mr_binary)."""
        client_binary = self.mr_binary
        return client_binary

    def search_for_port(self, itr):
        """Return the first port from itr that network.is_port_free() accepts.

        Raises SandboxTaskFailureError when itr is exhausted without a free port.
        """
        for candidate in itr:
            if network.is_port_free(candidate):
                break
        else:
            # Loop ended without a break: nothing in itr was free.
            raise SandboxTaskFailureError("Could not find free port for MR process")
        return candidate
