from __future__ import division

import json
import logging

from library.python.protobuf.json.proto2json import (
    FldNameMode,
    proto2json,
    Proto2JsonConfig,
)
import six
from six.moves import http_client
import yt.wrapper as yt


# Default per-request timeout in seconds (default of the --timeout option).
DEFAULT_TIMEOUT = 10.

# Default estimate of the maximum RPS a single job can sustain; used to derive
# the number of concurrent jobs from --max-rps when it is not given explicitly.
DEFAULT_MAX_RPS_PER_JOB = 1000
# Default maximum number of rows per job; used to derive the job count from the
# input table's row count when --job-count is not given explicitly.
DEFAULT_MAX_ROWS_PER_JOB = 250000
# Default number of times a host may fail to respond before it is excluded from
# the reachable hosts (default of the --max-fails option).
DEFAULT_MAX_FAILS = 3
# Default number of failed requests remembered per status code
# (default of the --failed-request-queue-size option).
DEFAULT_FAILED_REQUEST_QUEUE_SIZE = 5

logger = logging.getLogger(__name__)


class NoReachableHostsError(Exception):
    """Raised when none of the destination hosts is reachable from YT.

    ``value`` keeps the object passed to the constructor (a message in this
    module), and ``str(error)`` renders its ``repr``.
    """

    def __init__(self, value):
        # Forward to Exception so ``args`` is populated: pickle rebuilds an
        # exception as ``cls(*args)``, which failed with empty ``args`` because
        # ``value`` is a required constructor argument.
        super(NoReachableHostsError, self).__init__(value)
        self.value = value

    def __str__(self):
        return repr(self.value)


class PingMapper(object):
    """YT mapper that checks whether ``row["host"]`` accepts connections.

    For each input row it opens (and immediately closes) an HTTP connection
    to ``row["host"]`` and yields the same row with a ``reachable`` column:
    1 when the connection succeeded, 0 on any error.
    """

    def __init__(self, timeout):
        # Connection timeout in seconds, passed to HTTPConnection.
        self.timeout = timeout

    def __call__(self, row):
        connection = None
        try:
            connection = http_client.HTTPConnection(row["host"], timeout=self.timeout)
            connection.connect()
            row["reachable"] = 1
        except Exception:
            # Any error (malformed host string, connection refused, timeout, ...)
            # marks the host unreachable. ``Exception`` rather than a bare
            # ``except`` so KeyboardInterrupt/SystemExit still propagate.
            row["reachable"] = 0
        finally:
            # Close even when connect() failed so the socket is not leaked.
            if connection is not None:
                connection.close()

        yield row


def check_reachability(hosts, tmp_dir, timeout=1., yt_client=yt):
    """Split ``hosts`` into those reachable and unreachable from inside YT.

    The hosts are written to a temporary table and probed by ``PingMapper`` in
    a YT map operation, so the check runs from the cluster's jobs rather than
    from the calling machine.

    :param hosts: iterable of host strings accepted by ``HTTPConnection``
        (e.g. ``"host"`` or ``"host:port"``).
    :param tmp_dir: directory for the temporary tables; created recursively if
        missing.
    :param timeout: per-host connection timeout in seconds.
    :param yt_client: YT client to use (defaults to the ``yt.wrapper`` module).
    :return: ``(reachable_hosts, unreachable_hosts)`` as two lists.
    """
    # Fallback cleanup: should the explicit removal below not run, the temporary
    # tables expire after two days.
    expiration_timeout_in_ms = 2 * 24 * 3600 * 1000
    yt_client.mkdir(tmp_dir, recursive=True)

    temp_tables = []
    try:
        hosts_table = yt_client.create_temp_table(
            path=tmp_dir, prefix="hosts-", expiration_timeout=expiration_timeout_in_ms)
        temp_tables.append(hosts_table)
        result_table = yt_client.create_temp_table(
            path=tmp_dir, prefix="hosts-reachability-", expiration_timeout=expiration_timeout_in_ms)
        temp_tables.append(result_table)

        yt_client.write_table(hosts_table, [{"host": host} for host in hosts], raw=False)
        yt_client.run_map(PingMapper(timeout), hosts_table, result_table)

        reachable_hosts = []
        unreachable_hosts = []
        for row in yt_client.read_table(result_table):
            target = reachable_hosts if row["reachable"] else unreachable_hosts
            target.append(row["host"])
    finally:
        # Remove the temporary tables even when writing, mapping or reading fails.
        for table in temp_tables:
            yt_client.remove(table, force=True)

    return reachable_hosts, unreachable_hosts


def check_and_log_host_reachability(hosts, yt_tmp_dir):
    """Check ``hosts`` from YT, log the outcome and return the reachable ones.

    :param hosts: destination hosts to check.
    :param yt_tmp_dir: directory used for the temporary tables of the check.
    :return: list of hosts reachable from YT.
    :raises NoReachableHostsError: if no host is reachable.
    """
    logger.info("Destination hosts: %s", hosts)
    logger.info("Checking reachability from YT...")

    reachable_hosts, unreachable_hosts = check_reachability(hosts, yt_tmp_dir)
    if unreachable_hosts:
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
        logger.warning("Hosts that are unreachable from YT: %s", unreachable_hosts)
    if not reachable_hosts:
        raise NoReachableHostsError("No hosts are reachable.")

    logger.info("Reachable hosts: %s", reachable_hosts)
    return reachable_hosts


def add_mapper_args(parser, url_template, max_rows_per_job=DEFAULT_MAX_ROWS_PER_JOB,
                    max_rps_per_job=DEFAULT_MAX_RPS_PER_JOB, timeout=DEFAULT_TIMEOUT,
                    failed_request_queue_size=DEFAULT_FAILED_REQUEST_QUEUE_SIZE,
                    max_fails=DEFAULT_MAX_FAILS):
    """Register the request-mapper command-line options on an argparse ``parser``.

    The resulting option names (dashes normalized to underscores) match the
    settings of ``JobConfig``. The keyword arguments override the defaults of
    the corresponding options. ``--threads`` and ``--max-rps`` are required.

    :param parser: ``argparse.ArgumentParser`` (or compatible) to extend.
    :param url_template: default for ``--url-template``.
    :param max_rows_per_job: default for ``--max-rows-per-job``.
    :param max_rps_per_job: default for ``--max-rps-per-job``.
    :param timeout: default for ``--timeout`` (seconds).
    :param failed_request_queue_size: default for ``--failed-request-queue-size``.
    :param max_fails: default for ``--max-fails``.
    """
    parser.add_argument("--job-count", default=-1, help="Recommended job count. If -1, then derived from --rows-per-job", type=int)
    parser.add_argument("--max-concurrent-jobs", default=-1, help="Max concurrent job count. If -1, then derived from --max-rps", type=int)
    parser.add_argument("--url-template", default=url_template, help="URL template for sending graphs")
    parser.add_argument("--threads", required=True, help="Threads per job that send requests.", type=int)
    parser.add_argument("--timeout", default=timeout, help="Timeout for requests, seconds", type=float)
    parser.add_argument("--max-rps", required=True, help="Max RPS", type=int)
    parser.add_argument("--max-fails", default=max_fails,
                        help="If server did not respond to requests at least once this number of times, it is excluded "
                             "from the list of reachable hosts",
                        type=int)
    parser.add_argument("--max-rows-per-job", default=max_rows_per_job, help="Max number of rows per job", type=int)
    parser.add_argument("--max-rps-per-job", default=max_rps_per_job, help="Estimate for maximum RPS of single job", type=int)
    parser.add_argument("--failed-request-queue-size", default=failed_request_queue_size,
                        help="Number of failed requests remembered per status code. "
                             "Info about failed requests is output to stderr.", type=int)


def normalize_arg_name(name):
    """Turn a CLI-style name such as ``max-rps`` into an identifier (``max_rps``)."""
    return "_".join(name.split("-"))


class JobConfig(object):
    """Settings of a rate-limited request-sending YT map operation.

    Known settings are plain attributes. ``additional_args`` maps extra,
    caller-declared setting names to their values; ``build_mapper`` forwards
    them as extra keyword arguments to the mapper constructor.
    """

    def __init__(self, additional_args=None):
        """Create a config with default values.

        :param additional_args: names of extra settings accepted by
            :meth:`update` (dashes are normalized to underscores), given as a
            list, tuple or set. Any other value, including ``None``, declares
            no extra settings.
        """
        self.job_count = -1  # -1: derive from the table row count and max_rows_per_job
        self.max_concurrent_jobs = -1  # -1: derive from max_rps and max_rps_per_job
        self.url_template = None
        self.threads = 1
        self.timeout = DEFAULT_TIMEOUT
        self.max_rps = 1000
        self.max_fails = DEFAULT_MAX_FAILS
        self.max_rows_per_job = DEFAULT_MAX_ROWS_PER_JOB
        self.max_rps_per_job = DEFAULT_MAX_RPS_PER_JOB
        self.failed_request_queue_size = DEFAULT_FAILED_REQUEST_QUEUE_SIZE
        self.max_failed_job_count = None  # None: not set in the operation spec
        self.memory_limit = None  # None: not set in the operation spec
        self.batch_size = 1

        if isinstance(additional_args, (list, tuple, set, frozenset)):
            # Declared but not yet assigned extra settings start as None.
            self.additional_args = dict.fromkeys(normalize_arg_name(x) for x in additional_args)
        else:
            self.additional_args = {}

    def update(self, args):
        """Apply settings from ``args`` and return ``self`` (for chaining).

        :param args: a dict, or an object whose ``__dict__`` holds the settings
            (e.g. an ``argparse.Namespace``). Keys are normalized with
            ``normalize_arg_name``; known settings update attributes, declared
            extra settings update ``additional_args``.
        :raises KeyError: on a key that is neither; nothing is updated then.
        """
        source = args if isinstance(args, dict) else args.__dict__
        normalized = {normalize_arg_name(k): v for k, v in six.iteritems(source)}

        # Validate every key before changing anything so a bad key leaves the
        # config untouched.
        # NOTE(review): "additional_args" itself is in self.__dict__, so it is
        # accepted as a key and replaces the whole extra-settings mapping.
        for key in normalized:
            if key not in self.__dict__ and key not in self.additional_args:
                raise KeyError("Unknown key in job config: {}".format(key))

        known = {k: v for k, v in six.iteritems(normalized) if k in self.__dict__}
        self.__dict__.update(known)

        extra = {k: v for k, v in six.iteritems(normalized) if k in self.additional_args}
        self.additional_args.update(extra)
        return self


def job_config_from_args(args, additional_args=None):
    """Build a ``JobConfig`` from parsed arguments (a dict or a namespace).

    :param args: settings to apply, see ``JobConfig.update``.
    :param additional_args: names of extra settings to accept.
    :raises KeyError: if ``args`` contains an unknown setting.
    """
    config = JobConfig(additional_args)
    config.update(args)
    return config


def job_config_from_proto(proto, additional_args=None):
    """Build a ``JobConfig`` from a protobuf message.

    The message is converted to JSON with snake_case field names and used as
    the settings. Entries of ``additional_args`` (a dict) override those
    settings and are declared as extra settings, together with
    ``tvm_src_id``, ``tvm_dst_id`` and ``max_retries``.

    :raises KeyError: if the message yields a setting ``JobConfig`` does not know.
    """
    extra_values = additional_args or {}

    json_config = Proto2JsonConfig(
        field_name_mode=FldNameMode.FieldNameSnakeCase,
        map_as_object=True,
    )
    settings = json.loads(proto2json(proto, json_config))
    settings.update(extra_values)

    extra_names = list(extra_values.keys()) + ["tvm_src_id", "tvm_dst_id", "max_retries"]
    config = JobConfig(additional_args=extra_names)
    return config.update(settings)


def tvm_job_config_from_args(args, additional_args=None):
    """Like ``job_config_from_args``, but also accepts ``tvm_src_id`` and ``tvm_dst_id``.

    :param args: settings to apply (a dict or a namespace).
    :param additional_args: names of extra settings to accept; the caller's
        sequence is not modified.
    """
    # Build a new list: ``+=`` on the caller's list mutated it in place, adding
    # the TVM names to it on every call (and failed outright for tuples).
    additional_args = list(additional_args or []) + ["tvm_src_id", "tvm_dst_id"]
    return job_config_from_args(args, additional_args)


def get_job_settings(job_config, table, yt_client=yt):
    """Derive the sizing of the map operation from ``job_config``.

    :param job_config: ``JobConfig`` with the rate and size settings.
    :param table: input table; its ``row_count`` is read only when
        ``job_config.job_count`` is -1.
    :param yt_client: YT client used to read the row count.
    :return: ``(job_count, max_concurrent_jobs, rps_limit)``, where
        ``rps_limit`` is the per-job share of ``job_config.max_rps``.
    """
    max_concurrent_jobs = job_config.max_concurrent_jobs
    if max_concurrent_jobs == -1:
        # Integer ceil(max_rps / max_rps_per_job) for positive max_rps.
        max_concurrent_jobs = 1 + ((job_config.max_rps - 1) // job_config.max_rps_per_job)

    job_count = job_config.job_count
    if job_count == -1:
        row_count = yt_client.get_attribute(table, "row_count")
        # Integer ceil(row_count / max_rows_per_job); 0 for an empty table.
        job_count = 1 + ((row_count - 1) // job_config.max_rows_per_job)
    # Ensure there are at least as many jobs as concurrent slots.
    job_count = max(job_count, max_concurrent_jobs)

    rps_limit = int(round(float(job_config.max_rps) / max_concurrent_jobs))
    if rps_limit == 0 and job_config.max_rps > 0:
        # With an explicit max_concurrent_jobs much larger than max_rps the
        # per-job share rounds to 0; give each job at least 1 RPS instead.
        rps_limit = 1

    return job_count, max_concurrent_jobs, rps_limit


def build_mapper(mapper_constructor, hosts, job_config, rps_limit, is_https=False):
    """Instantiate the request-sending mapper described by ``job_config``.

    Known settings are passed as named keyword arguments, and every entry of
    ``job_config.additional_args`` is forwarded as an extra keyword argument.
    An extra setting that clashes with a named one raises ``TypeError``.

    :return: whatever ``mapper_constructor`` returns.
    """
    cfg = job_config
    return mapper_constructor(
        hosts=hosts,
        rps_limit=rps_limit,
        is_https=is_https,
        url=cfg.url_template,
        timeout=cfg.timeout,
        threads=cfg.threads,
        batch_size=cfg.batch_size,
        max_fails=cfg.max_fails,
        failed_request_queue_size=cfg.failed_request_queue_size,
        **cfg.additional_args
    )


def run_map(mapper_constructor, hosts, job_config, src_table, dst_table=None, additional_spec=None, yt_client=yt, is_https=False):
    """Run the rate-limited request-sending map operation over ``src_table``.

    :param mapper_constructor: mapper class/factory, see ``build_mapper``.
    :param hosts: destination hosts passed to the mapper.
    :param job_config: ``JobConfig`` with rate, sizing and mapper settings.
    :param src_table: input table.
    :param dst_table: output table, passed through to ``yt_client.run_map``.
    :param additional_spec: extra operation spec, deep-merged over the computed
        one (nested dicts are merged key by key, other values replace).
    :param yt_client: YT client to use.
    :param is_https: passed through to the mapper.
    :return: the result of ``yt_client.run_map``.
    """
    job_count, max_concurrent_jobs, rps_limit = get_job_settings(job_config, src_table, yt_client=yt_client)
    logger.info("Requested job count: total %d, concurrent %d", job_count, max_concurrent_jobs)

    mapper = build_mapper(mapper_constructor, hosts, job_config, rps_limit, is_https=is_https)
    # user_slots caps concurrency; together with the per-job rps_limit
    # (max_rps / max_concurrent_jobs) this keeps the total rate within max_rps.
    spec = {"resource_limits": {"user_slots": max_concurrent_jobs}}
    if job_config.max_failed_job_count is not None:
        spec["max_failed_job_count"] = job_config.max_failed_job_count
    if job_config.memory_limit is not None:
        spec["mapper"] = {"memory_limit": job_config.memory_limit}

    if additional_spec:
        # Deep-merge: a shallow update let e.g. {"resource_limits": {...}} or
        # {"mapper": {...}} silently drop user_slots / memory_limit.
        spec = _merge_spec(spec, additional_spec)
    return yt_client.run_map(mapper, src_table, dst_table, job_count=job_count, spec=spec)


def _merge_spec(base, override):
    """Return a new dict: ``base`` recursively updated with ``override``.

    Nested dicts present in both are merged key by key; any other value from
    ``override`` replaces the one in ``base``. Neither argument is modified.
    """
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = _merge_spec(merged[key], value)
        else:
            merged[key] = value
    return merged
