import gzip
import logging
import os.path
import re
import shutil
import tempfile

import pymongo
import requests
from requests.adapters import HTTPAdapter
from sandbox import sdk2
from sandbox.projects.music.resources import StaticsOffheapArtifacts
from urllib3.util.retry import Retry

from sandbox.projects.music.deployment.helpers.Config import CONFIG


class StaticsOffheapArtifactsUploader(object):
    """Copy statics "face" artifacts from MDS into a Sandbox resource.

    Artifact keys are read from the ``elliptics_metadata.basecaches`` MongoDB
    collection, each artifact is downloaded from MDS, decompressed when its
    metadata names a compression type, and moved into a new
    ``StaticsOffheapArtifacts`` resource.
    """

    # (connect, read) timeouts in seconds for MDS requests. Without a timeout a
    # stalled connection would block the task indefinitely.
    _MDS_TIMEOUT = (10, 300)

    @classmethod
    def execute(cls, task, db_host, db_user, db_password, prefix):
        """Download all matching MDS artifacts into a new resource.

        :param task: Sandbox task that owns the created resource.
        :param db_host: host part of the MongoDB connection URI.
        :param db_user: MongoDB user name.
        :param db_password: MongoDB password.
        :param prefix: path segment selecting ``statics/<prefix>/face/`` documents.
        :return: id of the created ``StaticsOffheapArtifacts`` resource.
        """
        mds_keys_cursor = cls._get_mds_keys(db_host, db_user, db_password, prefix)
        # NOTE(review): Cursor.count() is deprecated since PyMongo 3.7 and removed
        # in 4.0; switch to Collection.count_documents() if PyMongo is upgraded.
        logging.info("Found %s MDS keys to download", mds_keys_cursor.count())

        resource = StaticsOffheapArtifacts(task, "Artifacts from MDS for TestBigBlob", "mds-mock", ttl=30)
        resource_data = sdk2.ResourceData(resource)
        resource_path = str(resource_data.path)

        sess = cls._prepare_requests_session()
        for doc in mds_keys_cursor:
            key = doc["key"]
            logging.debug("Downloading %s", key)
            download_path = cls._download_artifact_from_mds(sess, doc)

            # A missing compressionType is treated like "none" (no decompression).
            compression_type = doc.get("compressionType", "none")
            if compression_type != "none":
                cls._decompress_artifact(download_path, compression_type)

            # Drop everything up to and including the first "/" of the key;
            # a key without "/" is used unchanged (find() returns -1).
            artifact_save_path = key[key.find("/") + 1:]
            dst = os.path.join(resource_path, artifact_save_path)
            dst_dir = os.path.dirname(dst)
            if not os.path.exists(dst_dir):
                os.makedirs(dst_dir)
            shutil.move(download_path, dst)

        resource_data.ready()
        return resource.id

    @staticmethod
    def _get_mds_keys(db_host, db_user, db_password, prefix):
        """Return a cursor over basecaches documents for ``statics/<prefix>/face/``.

        Only documents whose ``key`` ends with ``.m`` or ``.mm`` (case-insensitive)
        are selected. Each yielded document contains ``key`` and, when stored,
        ``compressionType``; ``_id`` is excluded.
        """
        try:
            from urllib.parse import quote_plus
        except ImportError:  # Python 2
            from urllib import quote_plus

        # Credentials must be percent-escaped in a MongoDB URI; a raw "@", ":"
        # or "/" in the password would otherwise break URI parsing.
        uri = "mongodb://{login}:{password}@{host}".format(
            login=quote_plus(db_user), password=quote_plus(db_password), host=db_host)
        mongo = pymongo.MongoClient(uri)
        collection = mongo.elliptics_metadata.basecaches

        # NOTE(review): prefix is interpolated unescaped, so regex metacharacters
        # in it are interpreted as regex syntax - confirm callers pass a plain segment.
        id_regex = re.compile(r"^statics/%s/face/" % prefix)
        key_regex = re.compile(r"\.mm?$", re.IGNORECASE)
        query = {"$and": [{"_id": id_regex}, {"key": key_regex}]}
        projection = {"key": 1, "_id": 0, "compressionType": 1}
        return collection.find(query, projection)

    @staticmethod
    def _prepare_requests_session():
        """Build a session with the MDS basic-auth header and retries for http:// URLs.

        The header value is read from the Vault secret named by
        ``CONFIG.mds_basic_token``; the adapter retries up to 5 times.
        """
        sess = requests.Session()
        auth_header = {"Authorization": "Basic %s" % sdk2.Vault.data(CONFIG.mds_basic_token)}
        sess.headers.update(auth_header)
        sess.mount('http://', HTTPAdapter(max_retries=Retry(5)))

        return sess

    @staticmethod
    def _download_artifact_from_mds(sess, doc, timeout=_MDS_TIMEOUT):
        """Stream ``doc["key"]`` from MDS into a new temporary file.

        :param sess: session returned by ``_prepare_requests_session``.
        :param doc: metadata document containing a ``key`` field.
        :param timeout: requests ``(connect, read)`` timeout in seconds.
        :return: path of the temporary file holding the downloaded bytes.
        :raises requests.HTTPError: when MDS answers with an error status.
        """
        url = "http://storage-int.mds.yandex.net/get-music-blobs/%s" % doc["key"]
        with sess.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            # mkstemp creates the file and keeps it. Taking
            # NamedTemporaryFile().name would delete the file on garbage
            # collection and leave only a reusable (racy) name.
            fd, download_path = tempfile.mkstemp()
            try:
                with os.fdopen(fd, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            except Exception:
                # Do not leave a partially written temp file behind.
                os.remove(download_path)
                raise
        return download_path

    @classmethod
    def _decompress_artifact(cls, download_path, compression_type):
        """Decompress ``download_path`` in place.

        The compressed data is moved aside to ``<download_path>_a`` and
        decompressed back into ``download_path``. The ``_a`` file is removed
        afterwards, even when decompression fails.

        :raises DecompressionNotImplementedError: for an unsupported
            ``compression_type``; the input file is left untouched in that case.
        """
        # Pick the decompressor before touching the file so an unsupported
        # type does not leave a stray "_a" copy behind.
        if compression_type == "zstd":
            decompress = cls._unzstd_file
        elif compression_type == "gzip":
            decompress = cls.ungzip_file
        else:
            logging.error("Decompression algorithm was not implemented for %s", compression_type)
            raise DecompressionNotImplementedError(
                "Decompression algorithm was not implemented for %s" % compression_type)

        input_file = download_path + "_a"
        shutil.move(download_path, input_file)
        try:
            decompress(input_file, download_path)
        finally:
            os.remove(input_file)

    @classmethod
    def _unzstd_file(cls, input_file, output_file):
        """Decompress zstd-compressed ``input_file`` into ``output_file``."""
        # Imported lazily so the module loads without zstandard installed.
        import zstandard as zstd

        with open(input_file, 'rb') as f_in, open(output_file, "wb") as f_out:
            zstd.ZstdDecompressor().copy_stream(f_in, f_out, write_size=65536)

    @classmethod
    def ungzip_file(cls, input_file, output_file):
        """Decompress gzip-compressed ``input_file`` into ``output_file``."""
        with gzip.open(input_file, 'rb') as f_in, open(output_file, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)


class DecompressionNotImplementedError(Exception):
    """Raised when an artifact's compression type has no decompressor."""
