import base64
import email
import json
import logging
import multiprocessing
import os

from six.moves.urllib.parse import (
    urljoin,
    urlparse,
)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Raise the threshold of urllib3's connection-pool logger to WARNING so its
# lower-severity records are dropped.
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)


# Marker put on the exts queue (once per consumer) by
# fill_cachedaemon_with_exts(); a put_ext_to_cachedaemon() worker returns
# when it reads it.
END_OF_QUEUE_SENTINEL = "END_OF_QUEUE"
# Ext tags that get special routing in this module: get_cachedaemon_port()
# maps them to a "yabstatNN" instance derived from the ext's shard, and
# put_ext_to_cachedaemon() prefixes the request path with the tag.
STAT_HANDLERS = [
    "content",
    "count",
    "get_data",
    "metacount",
    "pmatch",
    "ppstate",
    "rank",
    "stat_dict",
    "write_wide",
]


def create_session():
    """Build a ``requests`` session whose plain-HTTP requests are retried.

    ``requests`` is imported lazily, inside the function. The retry policy
    is configured with ``total=3``, ``backoff_factor=1`` and a list of HTTP
    status codes that force a retry. The retrying adapter is mounted for
    the ``http://`` prefix only.

    Returns:
        The configured session object.
    """
    import requests

    retryable_statuses = [429, 500, 502, 503, 504, 521]
    retry_policy = requests.packages.urllib3.Retry(
        total=3,
        backoff_factor=1,
        status_forcelist=retryable_statuses,
    )
    retrying_adapter = requests.adapters.HTTPAdapter(max_retries=retry_policy)

    http_session = requests.session()
    http_session.mount('http://', retrying_adapter)
    return http_session


def get_headers(headers_data):
    """Parse a raw header block into a plain dict.

    Args:
        headers_data: Either the header block itself as text, or a dict of
            the form ``{"$encoding": "base64", "$data": <base64 payload>}``
            wrapping it.

    Returns:
        A dict built with ``dict()`` over the parsed ``email`` message, so
        every header name appears once.

    Raises:
        ValueError: If ``headers_data`` is a dict whose ``$encoding`` is
            anything other than ``"base64"``.
        KeyError: If a base64-encoded dict has no ``$data`` entry.
    """
    headers_str = headers_data
    if isinstance(headers_data, dict):
        encoding = headers_data.get("$encoding")
        if encoding != "base64":
            raise ValueError("Unknown headers encoding: {}".format(encoding))
        headers_str = base64.b64decode(headers_data["$data"])

    # On Python 3 b64decode() returns bytes, which message_from_string()
    # rejects with TypeError. On Python 2 bytes is str, so this branch is
    # skipped there. latin-1 maps every byte to a code point, so decoding
    # cannot fail whatever the payload contains.
    if isinstance(headers_str, bytes) and not isinstance(headers_str, str):
        headers_str = headers_str.decode("latin-1")

    return dict(email.message_from_string(headers_str))


def get_response_body(ext):
    """Extract the HTTP response entity from an ext record.

    Args:
        ext: Ext record; ``ext["response"]["http"]`` must exist. Its
            optional ``"entity"`` value defaults to an empty string.

    Returns:
        The entity unchanged when it is a ``str``; the base64-decoded
        payload when it is a dict with ``"$encoding": "base64"``; ``None``
        when the entity is of any other type.

    Raises:
        ValueError: If the entity is a dict with an encoding other than
            ``"base64"``.
    """
    entity = ext["response"]["http"].get("entity", "")

    if isinstance(entity, str):
        return entity
    if not isinstance(entity, dict):
        # Neither plain text nor an encoded wrapper: nothing to return.
        return None

    encoding = entity.get("$encoding")
    if encoding != "base64":
        raise ValueError("Unknown response encoding: {}".format(encoding))
    return base64.b64decode(entity["$data"])

def get_all_exts(exts_data):
    """Flatten a tree of ext records into a list.

    An ext may carry child exts under ``response["exts"]``. Children are
    emitted (recursively) before the ext that contains them.

    Args:
        exts_data: Iterable of ext dicts. Each must have the keys ``ind``,
            ``headers``, ``request``, ``success`` and ``tag``; ``response``
            is optional and defaults to an empty dict.

    Returns:
        A list of new dicts, each holding only the six keys listed above.
    """
    flattened = []
    for ext in exts_data:
        response = ext.get("response", {})
        children = response.get("exts")
        if children:
            flattened.extend(get_all_exts(children))

        record = {
            "ind": ext["ind"],
            "headers": ext["headers"],
            "request": ext["request"],
            "response": response,
            "success": ext["success"],
            "tag": ext["tag"],
        }
        flattened.append(record)
    return flattened


def get_exts(path, out_queue, allowed_tags=None):
    """Read one responses file and queue every usable ext found in it.

    The file is parsed as JSON; if parsing fails a warning is logged and
    nothing is queued. Every ext (nested ones included) is then filtered.
    An ext is skipped when it has no HTTP status code, when it is not
    marked successful or its status code is not below 400, when its tag is
    not in ``allowed_tags``, when its body is empty or cannot be decoded,
    or when its headers cannot be decoded.

    Args:
        path: Path of the JSON file to read.
        out_queue: Object with a ``put()`` method that receives one dict
            per accepted ext.
        allowed_tags: Optional container of tags to keep. A falsy value
            disables tag filtering.
    """
    with open(path, 'r') as responses_file:
        try:
            data = json.load(responses_file)
        except ValueError:
            logger.warning("Failed to load %s", path)
            return

    # Normalise the id through int() before turning it into a string.
    request_id = str(int(data["request_id"]))

    for ext in get_all_exts(data["exts"]):
        tag = ext["tag"]

        try:
            status_code = ext["response"]["http"]["code"]
        except KeyError:
            logger.debug("Failed to get response status code, request_id: %s, tag: %s", request_id, tag)
            continue

        succeeded = ext["success"]
        if not (succeeded and status_code < 400):
            logger.debug(
                "Response is not ok, request_id: %s, tag: %s, status code: %s, success: %s",
                request_id, tag, status_code, succeeded,
            )
            continue

        if allowed_tags and tag not in allowed_tags:
            continue

        try:
            body = get_response_body(ext)
        except ValueError as e:
            logger.debug("Failed to get response body, request_id: %s, tag: %s: %s", request_id, tag, e)
            continue
        if not body:
            # Nothing to store for an empty or missing body.
            continue

        try:
            parsed_headers = get_headers(ext["headers"])
        except ValueError as e:
            logger.debug("Failed to get headers, request_id: %s, tag: %s: %s", request_id, tag, e)
            continue

        queued_ext = {
            "request_id": request_id,
            "tag": tag,
            "shard": ext["ind"],
            "query": ext["request"]["http"]["query"],
            "request_headers": parsed_headers,
            "response_body": body,
        }
        out_queue.put(queued_ext)


def get_cachedaemon_port(cachedaemon, ext):
    """Look up the cachedaemon port that should receive ``ext``.

    The instance is normally selected by the ext's tag. For tags listed in
    ``STAT_HANDLERS`` it is instead ``"yabstatNN"``, where ``NN`` is the
    ext's shard plus one, zero-padded to two digits.

    Args:
        cachedaemon: Object whose ``get_ports_by_tag()`` result is indexed
            by instance name.
        ext: Dict with a ``"tag"`` key, and a ``"shard"`` key when the tag
            is a stat handler.

    Returns:
        The port for the instance, or ``None`` (after logging an error)
        when the lookup raises ``KeyError``.
    """
    tag = ext["tag"]
    if tag in STAT_HANDLERS:
        instance_type = "yabstat{:02d}".format(int(ext["shard"]) + 1)
    else:
        instance_type = tag

    try:
        port = cachedaemon.get_ports_by_tag()[instance_type]
    except KeyError:
        logger.error("No cachedaemon instance for service \"%s\"", instance_type)
        return None
    return port


def put_ext_to_cachedaemon(exts_queue, cachedaemon):
    """Consume exts from a queue and PUT each one to a local cachedaemon.

    Runs until ``END_OF_QUEUE_SENTINEL`` is read from the queue. For every
    ext the target port is resolved with ``get_cachedaemon_port()``; exts
    without a port are skipped. Any ``requests`` failure for a single ext
    (connection error, timeout, exhausted retries, HTTP error status) is
    logged as a warning and the loop moves on to the next ext, so one bad
    request does not kill the worker.

    Args:
        exts_queue: Queue yielding ext dicts with the keys ``request_id``,
            ``tag``, ``shard``, ``query``, ``request_headers`` and
            ``response_body``, followed by the sentinel.
        cachedaemon: Object passed to ``get_cachedaemon_port()`` to resolve
            ports.
    """
    import requests

    session = create_session()

    while True:
        ext = exts_queue.get()
        if ext == END_OF_QUEUE_SENTINEL:
            logger.debug("%s found", END_OF_QUEUE_SENTINEL)
            return

        port = get_cachedaemon_port(cachedaemon, ext)
        if port is None:
            logger.debug(
                "Port not found: request_id: %s, tag: %s, shard: %s",
                ext["request_id"], ext["tag"], ext["shard"])
            continue

        headers = {
            "Content-Type": "application/octet-stream",
            "X-YaBS-Request-Id": ext["request_id"],
            # Explicitly disable parsing request as raw HTTP request https://st.yandex-team.ru/BSSERVER-23065
            "x-yabs-cachedaemon-payload-is-http-request": "0",
        }
        # Forward the ext request key only when the original request had one.
        # NOTE(review): this lookup is case-sensitive; confirm the header is
        # always spelled in lower case in the source data.
        try:
            headers["x-yabs-ext-request-key"] = ext["request_headers"]["x-yabs-ext-request-key"]
        except KeyError:
            pass

        path = urlparse(ext["query"]).path
        if ext["tag"] in STAT_HANDLERS:
            # Stat handlers are addressed by tag, so prefix the path with it.
            path = ext["tag"] + path
        url = urljoin("http://localhost:{}".format(port), path)
        request = requests.Request("PUT", url, headers=headers, data=ext["response_body"])

        # RequestException is the base class of HTTPError as well as of
        # ConnectionError, Timeout and RetryError, all of which can be raised
        # while preparing or sending the request.
        try:
            prepared_request = session.prepare_request(request)
            response = session.send(prepared_request, timeout=5)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            logger.warning("Failed to put data to cachedaemon: %s", e)


def fill_cachedaemon_with_exts(cachedaemon, responses_dir, allowed_tags=None):
    """Load every responses file in a directory into cachedaemon.

    A pool of reader processes runs ``get_exts()`` over each directory
    entry and pushes exts onto a shared manager queue. A second group of
    processes runs ``put_ext_to_cachedaemon()`` to drain that queue while
    ``cachedaemon`` is entered as a context manager. If any reader raises,
    all child processes are terminated and the exception is re-raised.

    Args:
        cachedaemon: Context manager that is entered for the duration of
            the upload and passed to every uploader process.
        responses_dir: Directory whose entries are each passed to
            ``get_exts()``.
        allowed_tags: Optional tag filter forwarded to ``get_exts()``.
    """
    get_exts_workers_number = 8
    put_to_cachedaemon_workers_number = 8

    # Keep a reference to the manager so its server process can be shut down
    # deterministically instead of whenever the queue proxy is collected.
    manager = multiprocessing.Manager()
    try:
        exts_queue = manager.Queue()

        process_pool = multiprocessing.Pool(get_exts_workers_number)
        get_exts_results = [
            process_pool.apply_async(get_exts, args=(os.path.join(responses_dir, filename), exts_queue, allowed_tags,))
            for filename in os.listdir(responses_dir)
        ]
        process_pool.close()

        put_ext_to_cachedaemon_processes = [
            multiprocessing.Process(
                target=put_ext_to_cachedaemon,
                args=(exts_queue, cachedaemon),
            )
            for _ in range(put_to_cachedaemon_workers_number)
        ]

        with cachedaemon:
            for process in put_ext_to_cachedaemon_processes:
                process.start()

            # Check for exceptions in get_exts() workers
            for async_result in get_exts_results:
                try:
                    async_result.get()
                except BaseException:
                    # Whatever went wrong (including an interrupt), stop all
                    # children before propagating the error.
                    process_pool.terminate()
                    for process in put_ext_to_cachedaemon_processes:
                        process.terminate()
                    raise

            # All tasks have finished and the pool is closed, so this only
            # reaps the reader processes.
            process_pool.join()

            logger.debug("Finished fetching exts from responses")
            # One sentinel per uploader so that each of them stops.
            for _ in range(put_to_cachedaemon_workers_number):
                exts_queue.put(END_OF_QUEUE_SENTINEL)

            for process in put_ext_to_cachedaemon_processes:
                process.join()
                logger.debug("Process %s joined", process.name)

            logger.debug("Finished filling cachedaemon")
    finally:
        manager.shutdown()
