# -*- coding: utf-8 -*-
import yt.wrapper as yt_wrapper
from datacloud.dev_utils.logging.logger import get_basic_logger

# Module-level logger, named after this module.
logger = get_basic_logger(__name__)

# Names exported by ``from ... import *``: the concrete storage backends.
__all__ = ['Memory', 'Yt']


class RecordStorage(object):
    """Base interface for saving and loading the records of a harvesting.

    The base implementation stores nothing: ``set_data`` discards its input
    and ``get_data`` always answers ``None`` ("nothing stored"). Concrete
    backends override both methods.
    """

    def __init__(self):
        """Create a storage; the base class keeps no state."""

    def set_data(self, harvesting, recs):
        """Store ``recs`` for ``harvesting``; the base class ignores them."""
        return None

    def get_data(self, harvesting):
        """Return records stored for ``harvesting``; the base class has none."""
        return None


class Memory(RecordStorage):
    """Record storage backed by an in-process dict.

    Records are keyed by ``harvesting.rel_path``; storing again under the
    same path replaces whatever was stored before.
    """

    def __init__(self):
        super(Memory, self).__init__()
        # rel_path -> records exactly as passed to set_data
        self._data = {}

    def get_data(self, harvesting):
        """Return the records stored for ``harvesting``, or ``None`` if absent."""
        return self._data.get(harvesting.rel_path)

    def set_data(self, harvesting, recs):
        """Remember ``recs`` under ``harvesting.rel_path``, overwriting."""
        key = harvesting.rel_path
        self._data[key] = recs


class Yt(RecordStorage):
    """Record storage that keeps each harvesting's records in a YT table.

    The table path is ``yt_root`` joined (via ``yt_wrapper.ypath_join``) with
    ``harvesting.rel_path``.

    :param yt_client: YT client used for ``exists``, ``read_table``,
        ``write_table``, ``create`` and ``Transaction``.
    :param yt_root: YT directory under which the tables are placed.
    """

    def __init__(self, yt_client, yt_root):
        super(Yt, self).__init__()
        self._yt_client = yt_client
        self.yt_root = yt_root

    def get_data(self, harvesting):
        """Return the rows of the harvesting's table, or ``None`` if absent.

        The result is whatever ``yt_client.read_table`` returns (presumably a
        lazy row iterator -- confirm against the client version in use).
        """
        table = self._get_yt_table_path(harvesting)
        if self._yt_client.exists(table):
            return self._yt_client.read_table(table)
        return None

    def set_data(self, harvesting, recs):
        """Write ``recs`` to the harvesting's table unless it already exists.

        If the table exists, a warning is logged and nothing is written.
        Otherwise, inside one YT transaction, the parent directory is created
        (recursively) when missing and ``recs`` is written to the table.
        """
        table = self._get_yt_table_path(harvesting)
        if self._yt_client.exists(table):
            # Honour the "SKIPPED" contract: never overwrite an existing table.
            logger.warning('[STORAGE] table %s exists. SKIPPED', table)
            return
        folder = self._get_yt_folder(harvesting)
        with self._yt_client.Transaction():
            if not self._yt_client.exists(folder):
                self._yt_client.create('map_node', folder, recursive=True)
            self._yt_client.write_table(table, recs)

    def _get_yt_table_path(self, harvesting):
        """Return the full YT path of the harvesting's table."""
        return yt_wrapper.ypath_join(self.yt_root, harvesting.rel_path)

    def _get_yt_folder(self, harvesting):
        """Return the parent directory of the harvesting's table path."""
        # Drop the last '/'-separated component of the table path.
        return '/'.join(self._get_yt_table_path(harvesting).split('/')[:-1])
