import json
import logging
import os
import re
import shutil
import tarfile
from collections import OrderedDict
from multiprocessing import Manager
from tempfile import mkdtemp

import cPickle as pickle
import numpy as np
import torch
from catboost import CatBoostClassifier
from flask import current_app as app
from scipy import sparse

from jafar.storages.base import Storage
from jafar.storages.exceptions import StorageKeyError
from jafar.storages.flask_jafar_storage import StorageWrapper
from jafar.utils.cmph_wrapper import load_hash
from jafar.utils.static_dict import StaticDict, StaticMapping

# Values of the 'type' field kept in each key's meta entry; `_load_key`
# dispatches on them to pick the matching loader.
ARRAY_TYPE = 'array'
DICT_TYPE = 'dict'
MAPPING_TYPE = 'mapping'
SPARSE_MATRIX = 'sparse_matrix'
CATBOOST_TYPE = 'catboost'
NN_WEIGHTS_TYPE = 'nn_weights'

# module-level logger named after this module
logger = logging.getLogger(__name__)


def is_shape_of_empty_array(shape):
    """
    Given numpy array's `shape` tuple, returns True, if
    it's a shape of an empty array and False otherwise.

    Shapes can have zero dimensions in case of scalars (`np.array(1)`),
    in which case it's always non-empty. Otherwise, an array is
    empty if any of its dimensions equals to zero: both
    `np.array([[[]]])` (shape `(1, 1, 0)`) and `np.empty((0, 5))`
    hold no elements.
    """
    if not shape:
        # scalar: zero dimensions, exactly one element
        return False
    # the number of elements is the product of all dimensions, so a single
    # zero anywhere (not only in the last position) makes the array empty
    return any(dim == 0 for dim in shape)


class MemmapStorage(Storage):
    """
    Storage that keeps values as files inside `MEMMAP_STORAGE_DIR`.

    Every stored key has an entry in `self.meta` (persisted to `meta.pkl`)
    holding at least its 'type'; the type selects how the key is loaded back.
    Loaded objects are cached in `self._storage`.
    """
    allow_parallel_access = False

    def __init__(self):
        if not os.path.exists(self.base_dir):
            os.makedirs(self.base_dir)
        # in-memory cache of already loaded objects, keyed by storage key
        self._storage = {}
        self.init_meta()
        self.load_meta()

    def init_meta(self):
        """ Creates an empty meta container. """
        self.meta = {}

    def load_meta(self):
        """
        Reads `meta.pkl` (if present) into `self.meta`. In eager mode
        every key listed in meta is loaded into the cache right away.
        """
        if not os.path.exists(self.meta_path):
            return
        # pickles are binary data, so the file is opened in binary mode
        with open(self.meta_path, 'rb') as f:
            self.meta.update(pickle.load(f))
        if app.config['MEMMAP_STORAGE_MODE'] == app.config['MEMMAP_STORAGE_EAGER_MODE']:
            for key in self.meta.keys():
                self._storage[key] = self._load_key(key)

    @property
    def base_dir(self):
        """ Directory holding all storage files. """
        return app.config['MEMMAP_STORAGE_DIR']

    @property
    def meta_path(self):
        """ Path of the pickled meta file inside `base_dir`. """
        return os.path.join(self.base_dir, 'meta.pkl')

    def _load_key(self, key):
        """
        Loads the object stored under `key` using the loader matching
        its meta 'type'.

        Raises KeyError if `key` has no meta entry and TypeError if
        the recorded type is unknown.
        """
        logger.debug('Loading %s', key)
        key_meta = self.meta[key]
        key_type = key_meta['type']
        loaders = {
            ARRAY_TYPE: self._load_array,
            MAPPING_TYPE: self._load_mapping,
            DICT_TYPE: self._load_dict,
            SPARSE_MATRIX: self._load_sparse_matrix,
            CATBOOST_TYPE: self._load_catboost,
            NN_WEIGHTS_TYPE: self._load_nn_weights
        }
        if key_type not in loaders:
            raise TypeError('Got unknown key type %s' % key_type)
        return loaders[key_type](key, key_meta)

    def _load_array(self, key, key_meta):
        """
        Loads an array: empty arrays are rebuilt from meta only, object
        arrays are unpickled, everything else is memory-mapped.
        """
        shape = key_meta['shape']
        dtype = key_meta['dtype']
        if is_shape_of_empty_array(shape):
            return np.empty(dtype=dtype, shape=shape)
        path = self.get_path(key)
        if dtype.hasobject:
            # NOTE: unpickling executes arbitrary code; storage files must be trusted
            with open(path, 'rb') as f:
                return pickle.load(f)
        # mode='c' is copy-on-write: in-memory changes are never written to disk
        return np.memmap(path, dtype=dtype, mode='c', shape=shape)

    def _load_mapping(self, key, key_meta):
        """ Builds a StaticMapping from the `<key>.array` array and the optional hash file. """
        hash_function = self._load_hash_function(key)
        array = self._load_key(key + '.array')
        return StaticMapping(array, hash_function)

    def _load_dict(self, key, key_meta):
        """ Builds a StaticDict from `<key>.key_map`, `<key>.keys` and `<key>.values`. """
        key_map = self._load_key(key + '.key_map')
        array = self._load_key(key + '.values')
        keys = self._load_key(key + '.keys')
        serialization = key_meta['serialization']
        return StaticDict(key_map, keys, array, serialization)

    def _load_sparse_matrix(self, key, key_meta):
        """ Rebuilds a scipy sparse matrix of the recorded format from its three arrays. """
        shape = key_meta['shape']
        fmt = key_meta['format']
        data = self._load_key(key + '.data')
        indices = self._load_key(key + '.indices')
        indptr = self._load_key(key + '.indptr')
        # resolves to e.g. sparse.csr_matrix or sparse.csc_matrix
        matrix_class = getattr(sparse, fmt + '_matrix')
        return matrix_class((data, indices, indptr), shape=shape)

    def _load_catboost(self, key, key_meta):
        """ Loads a CatBoostClassifier from the key's file. """
        result = CatBoostClassifier()
        result.load_model(self.get_path(key))
        return result

    def _load_nn_weights(self, key, key_meta):
        """ Loads object saved with `torch.save`, mapping tensors to CPU. """
        return torch.load(self.get_path(key), map_location='cpu')

    def update_meta_key(self, key, **kwargs):
        """ Merges `kwargs` into the meta entry of `key`, creating the entry if needed. """
        if key not in self.meta:
            self.meta[key] = {}
        self.meta[key].update(**kwargs)

    def _get_hash_path(self, key):
        return os.path.join(self.base_dir, key + '.hash')

    def get_path(self, key):
        """ Path of the file backing `key`. """
        return os.path.join(self.base_dir, key)

    def _load_hash_function(self, key):
        """ Returns the hash function dumped for `key` or None if there is no hash file. """
        path = self._get_hash_path(key)
        if os.path.exists(path):
            return load_hash(path)
        return None

    def _store_hash_function(self, key, hash_function):
        if hash_function:
            # NOTE: None hash_function is possible
            path = self._get_hash_path(key)
            hash_function.dump(path)

    def remove(self, key):
        """
        Removes `key` from meta, from the in-memory cache and from disk,
        then persists meta.

        Raises KeyError if `key` is not present in meta.
        NOTE(review): only the file and meta entry of `key` itself are
        removed; sub-keys of composite values (`<key>.data`, `<key>.values`,
        hash files, ...) are left untouched.
        """
        logger.debug("Removing %s", key)
        del self.meta[key]
        # drop the cached object too, otherwise get_object/get_proxy
        # would keep returning the removed value
        self._storage.pop(key, None)
        try:
            os.remove(self.get_path(key))
        except OSError:
            logger.warning("Couldn't find key %s on disk; already removed?", key)
        self.update_meta()

    def store(self, key, value):
        """
        Stores `value` under `key`, choosing the representation by type:
        sparse matrix, non-empty OrderedDict of torch tensors (nn weights),
        dict, CatBoostClassifier; anything else is converted with `np.array`.
        The stored value is reloaded into the cache and meta is persisted.
        """
        logger.debug("Storing %s", key)
        if isinstance(value, sparse.spmatrix):
            self.store_sparse_matrix(key, value)
        elif isinstance(value, OrderedDict) \
                and value \
                and isinstance(next(iter(value.values())), torch.Tensor):
            # next(iter(...)) peeks at the first value without requiring
            # `values()` to return an indexable list
            self.store_nn_weights(key, value)
        elif isinstance(value, dict):
            self.store_dict(key, value)
        elif isinstance(value, CatBoostClassifier):
            self.store_catboost(key, value)
        else:  # assume other types can be converted to array
            self.store_array(key, np.array(value))
        self._storage[key] = self._load_key(key)
        self.update_meta()

    def store_array(self, key, value):
        """
        Writes numpy array `value` to the key's file and records its dtype
        and shape in meta. Object arrays are pickled; empty non-object arrays
        get no file at all, only a meta entry.
        """
        path = self.get_path(key)
        if value.dtype.hasobject:
            # object arrays are pickled
            with open(path, 'wb') as f:
                pickle.dump(value, f)
        elif value.size > 0:
            mapped_array = np.memmap(path, dtype=value.dtype, mode='w+', shape=value.shape)
            mapped_array[...] = value
            mapped_array.flush()
            del mapped_array  # file is closed on delete
        self.update_meta_key(key, dtype=value.dtype, shape=value.shape, type=ARRAY_TYPE)

    def store_sparse_matrix(self, key, value):
        """ Stores CSR/CSC matrix as three arrays plus shape and format in meta. """
        assert isinstance(value, sparse.spmatrix)
        assert value.format in ('csc', 'csr'), 'Only compressed CSR and CSC matrixes are supported'
        self.store_array(key + '.data', value.data)
        self.store_array(key + '.indices', value.indices)
        self.store_array(key + '.indptr', value.indptr)
        self.update_meta_key(key, shape=value.shape, format=value.format, type=SPARSE_MATRIX)

    def store_mapping(self, key, mapping):
        """
        Stores mapping's hash function and its array (as `<key>.array`).
        Does not persist meta: callers are expected to call `update_meta`.
        """
        self.update_meta_key(key, type=MAPPING_TYPE)
        self._store_hash_function(key, mapping.hash_function)
        self.store_array(key + '.array', mapping.array)

    def store_dict(self, key, dictionary):
        # dicts are converted to instances of jafar.utils.static_dict.StaticDict
        static_dict = StaticDict.create(dictionary.keys(), dictionary.values())
        self.update_meta_key(key, serialization=static_dict.serialization, type=DICT_TYPE)
        self.store_array(key + '.values', static_dict.values_array)
        self.store_array(key + '.keys', static_dict.keys_array)
        self.store_mapping(key + '.key_map', static_dict.key_map)

    def store_catboost(self, key, value):
        """ Saves CatBoost model to the key's file. """
        value.save_model(self.get_path(key))
        self.update_meta_key(key, type=CATBOOST_TYPE)

    def store_nn_weights(self, key, value):
        """ Saves `value` to the key's file with `torch.save`. """
        torch.save(value, self.get_path(key))
        self.update_meta_key(key, type=NN_WEIGHTS_TYPE)

    def update_meta(self):
        """ Persists a plain-dict copy of `self.meta` to `meta.pkl`. """
        # HIGHEST_PROTOCOL is a binary protocol, hence binary mode
        with open(self.meta_path, 'wb') as f:
            pickle.dump(dict(self.meta), f, protocol=pickle.HIGHEST_PROTOCOL)

    def _get_cached(self, key):
        """ Returns the cached object for `key`, loading it on first access. """
        if key not in self._storage:
            self._storage[key] = self._load_key(key)
        return self._storage[key]

    def get_object(self, key):
        """
        Returns the value stored under `key`; StaticDict and StaticMapping
        values are converted to plain dicts.

        Raises StorageKeyError if a KeyError occurs while fetching the value.
        """
        try:
            value = self._get_cached(key)
            if isinstance(value, (StaticDict, StaticMapping)):
                return dict(value)
            return value
        except KeyError:
            raise StorageKeyError(key)

    def get_proxy(self, key):
        """
        Returns the cached object for `key` as is (no conversion to dict).

        Raises StorageKeyError if a KeyError occurs while fetching the value.
        """
        try:
            return self._get_cached(key)
        except KeyError:
            raise StorageKeyError(key)

    def has_key(self, key):
        return key in self.meta

    def get_dict_values(self, key, dict_keys):
        """
        Looks up every item of `dict_keys` in the proxy stored under `key`
        and returns results as an array of the proxy's `values_dtype`.
        """
        proxy = self.get_proxy(key)
        return np.array([proxy.get(k) for k in dict_keys], dtype=proxy.values_dtype)

    def get_matrix_rows(self, key, idx):
        return self.get_proxy(key)[idx]

    def dump(self, file_name):
        """
        Packs storage files and `meta.pkl` into gzipped tar `file_name`.

        Only types that own a file are added: non-empty arrays, mapping
        hash files, CatBoost models and nn weights.
        """
        with tarfile.open(file_name, "w:gz") as tar:
            # there are some issues when iterating self.meta like an ordinary dict
            for key, key_meta in self.meta.items():
                key_type = key_meta['type']
                if key_type == ARRAY_TYPE and not is_shape_of_empty_array(key_meta['shape']):
                    logger.debug("Dumping array '%s'", key)
                    tar.add(self.get_path(key), arcname=key)
                elif key_type == MAPPING_TYPE:
                    hash_path = self._get_hash_path(key)
                    # mappings stored with a None hash function have no hash
                    # file (see _store_hash_function); adding it would fail
                    if os.path.exists(hash_path):
                        logger.debug("Dumping hash '%s'", key + '.hash')
                        tar.add(hash_path, arcname=key + '.hash')
                elif key_type == CATBOOST_TYPE:
                    logger.debug("Dumping CatBoost %s", key)
                    tar.add(self.get_path(key), arcname=key)
                elif key_type == NN_WEIGHTS_TYPE:
                    logger.debug("Dumping NN Module %s", key)
                    tar.add(self.get_path(key), arcname=key)

            logger.debug("Dumping meta.pkl")
            tar.add(self.meta_path, arcname='meta.pkl')

    def make_mapping(self, key, values):
        """
        Creates a StaticMapping from `values` and stores it under `key`.
        Like `store`, persists meta afterwards and makes sure a previously
        cached object for `key` is not served anymore.
        """
        mapping = StaticMapping.create(values)
        self.store_mapping(key, mapping)
        # the next get_proxy/get_object call reloads the fresh mapping
        self._storage.pop(key, None)
        self.update_meta()

    def map_values(self, mapping_key, values, default, reverse):
        """
        Maps `values` through the mapping stored under `mapping_key`
        (or through its `reverse` if `reverse` is truthy).
        """
        mapping = self.get_proxy(mapping_key)
        if reverse:
            mapping = mapping.reverse
        return mapping.map(values, default)

    def to_dict(self):
        """ Returns the internal cache object itself (not a copy). """
        return self._storage

    @classmethod
    def from_dict(cls, dictionary):
        """ Creates a storage which uses `dictionary` as its cache. """
        instance = cls()
        instance._storage = dictionary
        return instance


class MultiprocessMemmapStorage(MemmapStorage):
    """
    Multiprocessing memmap storage. Used mainly for training.
    Meta lives in a sync manager dict; disk writes and meta (re)loading
    are wrapped into a manager lock.
    NOTE: Does not work with gevent sockets
    NOTE(review): `update_meta_key` itself does not take the lock - confirm
    this is intended.
    """
    allow_parallel_access = True

    def __init__(self):
        logger.debug("Creating multiprocessing sync manager primitives")
        manager = Manager()
        self.sync_manager = manager
        # RLock is used to re-enter lock in case of updating meta while storing array
        self.lock = manager.RLock()
        super(MultiprocessMemmapStorage, self).__init__()

    def init_meta(self):
        """ Meta is a manager-backed dict proxy instead of a plain dict. """
        shared_meta = self.sync_manager.dict()
        self.meta = shared_meta

    def update_meta_key(self, key, **kwargs):
        """ syncmanager dictproxy should be notified about object change via __setitem__ """
        entry = self.meta.get(key, {})
        entry.update(kwargs)
        # write the whole entry back so that the proxy sees the change
        self.meta[key] = entry

    def load_meta(self):
        """ Same as parent's `load_meta`, performed under the lock. """
        parent = super(MultiprocessMemmapStorage, self)
        with self.lock:
            parent.load_meta()

    def update_meta(self):
        """ Same as parent's `update_meta`, performed under the lock. """
        parent = super(MultiprocessMemmapStorage, self)
        with self.lock:
            parent.update_meta()

    def store_array(self, key, value):
        """ Same as parent's `store_array`, performed under the lock. """
        parent = super(MultiprocessMemmapStorage, self)
        with self.lock:
            parent.store_array(key, value)

    def _store_hash_function(self, key, hash_function):
        """ Same as parent's `_store_hash_function`, performed under the lock. """
        parent = super(MultiprocessMemmapStorage, self)
        with self.lock:
            parent._store_hash_function(key, hash_function)


class MemmapStorageWrapper(StorageWrapper):
    """ Storage wrapper whose `connect` produces a `MemmapStorage`. """

    def connect(self):
        """ Creates and returns a new `MemmapStorage` instance. """
        storage = MemmapStorage()
        return storage


def parse_version_info(file_name):
    """
    Extracts snapshot version and date from snapshot file name.

    File name should be jafar_{version}_{date in YYYY-MM-DDTHH-MM-SS format}.tar.gz
    (':' is accepted as time separator as well),
    example: jafar_1_2017-11-03T07-58-54.tar.gz

    Returns dict with integer 'version' and string 'date'.
    Raises ValueError if `file_name` doesn't follow the format.
    """
    # \Z anchors the end of the name, so that names with extra trailing
    # characters (e.g. "....tar.gz.part") are not taken for valid snapshots
    match = re.match(r'jafar_(\d+)_(\d{4}-\d{2}-\d{2}T\d{2}[-:]\d{2}[-:]\d{2})\.tar\.gz\Z', file_name)
    if not match:
        raise ValueError("Expected 'jafar_{version}_{date in YYYY-MM-DDTHH-MM-SS format}.tar.gz' format")
    version, date = match.groups()
    return {
        'version': int(version),
        'date': date
    }


def get_version_info_path():
    """ Returns path of `version_info.json` inside `MEMMAP_STORAGE_DIR`. """
    storage_dir = app.config['MEMMAP_STORAGE_DIR']
    return os.path.join(storage_dir, 'version_info.json')


def get_current_version_info():
    """ Reads and returns the JSON content of the version info file. """
    path = get_version_info_path()
    with open(path, 'r') as f:
        raw = f.read()
    return json.loads(raw)


def restore_dump(file_name, lazy=False):
    """
    Replaces content of `MEMMAP_STORAGE_DIR` with the snapshot archive
    `file_name` and records snapshot's version info next to it.

    Version and date are parsed from the archive's real (symlink-resolved)
    file name. If `lazy` is on and the same version info is already
    recorded, nothing is done.

    Raises ValueError if file name has unexpected format or snapshot
    version differs from `SNAPSHOT_VERSION`.
    """
    version_info = parse_version_info(os.path.basename(os.path.realpath(file_name)))
    # if lazy flag is on, check whether corresponding version is already installed
    if lazy:
        try:
            installed_version_info = get_current_version_info()
        except (ValueError, OSError, IOError):
            # could be parse error, file missing, or permission denied
            installed_version_info = None
        if version_info == installed_version_info:
            logger.info(
                "Snapshot version %s is already installed, not doing anything because of lazy=True",
                version_info
            )
            return

    if version_info['version'] != app.config['SNAPSHOT_VERSION']:
        raise ValueError(
            "Snapshot versions mismatch: trying to restore version {}, but version {} is supported".format(
                version_info['version'], app.config['SNAPSHOT_VERSION']
            )
        )

    storage_dir = app.config['MEMMAP_STORAGE_DIR']
    logger.info("Extracting snapshot: version %d, %s", version_info['version'], version_info['date'])
    # extract to a temporary directory first, so that a broken archive
    # doesn't destroy the currently installed content
    tempdir = mkdtemp()
    try:
        with tarfile.open(file_name, "r:gz") as tar:
            # NOTE: extractall doesn't protect from members with absolute or
            # "../" paths; snapshot archives must come from a trusted source
            tar.extractall(tempdir)
    except Exception:
        # don't leave half-extracted content behind
        shutil.rmtree(tempdir, ignore_errors=True)
        raise

    # cleaning up old content; there is none on the very first restore
    if os.path.exists(storage_dir):
        shutil.rmtree(storage_dir)
    shutil.move(tempdir, storage_dir)

    # keep version info for snapshot updates
    with open(get_version_info_path(), 'w') as f:
        json.dump(version_info, f, indent=4)
