from yt import wrapper as default_yt
from contextlib import contextmanager

from .constants import CompressionCodec, ErasureCodec, YtCluster, OptimizeFor
from .state_check import NotExists
from .client import create_yt_client


# Names exported by ``from <this module> import *``; only the table wrapper
# class is public, the underscore-prefixed helpers below are internal.
__all__ = [
    'BaseTable',
]


def _maybe_enum(enum, value):
    if value is None:
        return value
    else:
        return enum(value)


class BaseTable(object):
    """Description of a single YT table plus helpers that operate on it.

    The instance stores the table path together with the optional settings
    (schema, codecs, media, ``optimize_for``) that are sent as attributes when
    the table is created. Every operation is delegated to ``self.yt_client``
    and applied to ``self.path``.
    """

    def __init__(self, path, yt_client=None, yt_cluster=None, optimize_for=None,
                 schema=None, compression_codec=None, erasure_codec=None, media=None):
        """
        :param path: table path; passed as-is to the YT client calls
        :param yt_client: client object to use; when ``None`` the module-level
            ``yt.wrapper`` is used instead (see :attr:`yt_client`)
        :param yt_cluster: value convertible to ``YtCluster`` or ``None``;
            required only by :meth:`lock`
        :param optimize_for: value convertible to ``OptimizeFor`` or ``None``
        :param schema: table schema or ``None``
        :param compression_codec: value convertible to ``CompressionCodec`` or ``None``
        :param erasure_codec: value convertible to ``ErasureCodec`` or ``None``
        :param media: ``None`` or a sequence of items exposing ``name``,
            ``replication_factor`` and ``data_parts_only``; the first item
            names the primary medium
        """
        super(BaseTable, self).__init__()

        self._yt = yt_client
        self.yt_cluster = _maybe_enum(YtCluster, yt_cluster)
        self.path = path
        self.schema = schema
        self.optimize_for = _maybe_enum(OptimizeFor, optimize_for)
        self.compression_codec = _maybe_enum(CompressionCodec, compression_codec)
        self.erasure_codec = _maybe_enum(ErasureCodec, erasure_codec)
        self.media = media  # or [MediaItem()]
        if self.media is not None:
            # The first listed medium is treated as the primary one.
            self.primary_medium = self.media[0].name
        else:
            self.primary_medium = None

    def __str__(self):
        """Return the table path as a string."""
        return str(self.path)

    @property
    def yt_client(self):
        """Client given to the constructor, or ``yt.wrapper`` when none was given."""
        return self._yt if self._yt is not None else default_yt

    def get_create_attributes(self):
        """Build the attribute dict used when creating the table.

        Only settings that are not ``None`` produce a key.

        :return: dict with any of ``schema``, ``compression_codec``,
            ``erasure_codec``, ``primary_medium``, ``optimize_for``, ``media``
        """
        attributes = {}
        if self.schema is not None:
            attributes["schema"] = self.schema
        if self.compression_codec is not None:
            attributes["compression_codec"] = self.compression_codec.value
        if self.erasure_codec is not None:
            attributes["erasure_codec"] = self.erasure_codec.value
        if self.primary_medium is not None:
            attributes["primary_medium"] = self.primary_medium
        if self.optimize_for is not None:
            attributes["optimize_for"] = self.optimize_for.value
        if self.media is not None:
            attributes["media"] = {
                m.name: {
                    'replication_factor': m.replication_factor,
                    'data_parts_only': m.data_parts_only,
                }
                for m in self.media
            }
        return attributes

    def get_attribute(self, attr, default=None):
        """Return attribute ``attr`` of the table via ``yt_client.get_attribute``.

        :param attr: attribute name
        :param default: forwarded to the client as ``default``
        """
        return self.yt_client.get_attribute(self.path, attr, default=default)

    def set_attribute(self, attr, value):
        """Set attribute ``attr`` of the table to ``value`` via ``yt_client.set_attribute``."""
        return self.yt_client.set_attribute(self.path, attr, value)

    def compress(self):
        """Call ``yt_client.transform`` on the table with the configured codecs.

        Only codecs that are not ``None`` are passed as keyword arguments.
        """
        attributes = {}
        if self.erasure_codec is not None:
            attributes["erasure_codec"] = self.erasure_codec.value
        if self.compression_codec is not None:
            attributes["compression_codec"] = self.compression_codec.value
        self.yt_client.transform(self.path, **attributes)

    def exists(self):
        """Return the result of ``yt_client.exists`` for the table path."""
        return self.yt_client.exists(self.path)

    def remove(self):
        """Remove the table if it exists; return ``None`` when it does not."""
        if self.yt_client.exists(self.path):
            return self.yt_client.remove(self.path)

    def create_table(self, force=False, attrs=None):
        """Create the table unless it already exists.

        :param bool force: when true, an existing table is removed first
        :param dict attrs: extra attributes layered over
            :meth:`get_create_attributes`; a key mapped to ``None`` removes
            that attribute from the final set
        """
        attributes = self.get_create_attributes()
        attrs = attrs or {}
        attributes.update(attrs)
        # ``items()`` instead of the Python-2-only ``iteritems()`` so this
        # works on both Python 2 and Python 3.
        for key, value in attrs.items():
            if value is None:
                # Present after the update() above, so pop() cannot fail.
                attributes.pop(key)
        if force and self.exists():
            self.remove()
        if not self.exists():
            self.yt_client.create('table', self.path, recursive=True, attributes=attributes)

    def alter_table(self, schema=None):
        """Call ``yt_client.alter_table`` with ``schema``; no-op when ``schema`` is falsy."""
        if not schema:
            return
        self.yt_client.alter_table(self.path, schema=schema)

    def read_table(self, *args, **kwargs):
        """Forward to ``yt_client.read_table`` with the table path prepended."""
        return self.yt_client.read_table(self.path, *args, **kwargs)

    def write_table(self, rows, *args, **kwargs):
        """Forward ``rows`` to ``yt_client.write_table`` with the table path prepended."""
        return self.yt_client.write_table(self.path, rows, *args, **kwargs)

    def verify_state(self):
        """Verify the current YT table and iterate over the state differences found.

        Yields a ``NotExists`` for this table when it does not exist and
        nothing otherwise.
        """
        if not self.exists():
            yield NotExists(self)
            # A plain ``return`` ends the generator. ``raise StopIteration()``
            # inside a generator is turned into RuntimeError by PEP 479.
            return

    @contextmanager
    def lock(self, wait_for=None):
        """
        Creates exclusive lock for table

        The lock is taken inside a transaction opened on a client built with
        ``create_yt_client(self.yt_cluster)`` (not ``self.yt_client``); the
        ``with`` body of the caller runs inside that transaction.

        :param int wait_for: wait interval in milliseconds; any falsy value
            (``None``, ``0``) is replaced by 1000
        """
        wait_for = wait_for or 1000  # wait one second by default to avoid manual lock state check
        assert self.yt_cluster is not None
        yt_client = create_yt_client(self.yt_cluster)
        with yt_client.Transaction(attributes={"title": "xprod-schema lock transaction"}):
            # NOTE: ``wait_for`` is never None after the defaulting above, so
            # ``waitable`` is always true here.
            yt_client.lock(self.path, waitable=wait_for is not None, wait_for=wait_for)
            yield

    def ensure_schema_is_valid(self):
        """Bring the table's schema in line with ``self.schema``.

        Does nothing when no schema is configured. Otherwise the table's
        ``schema`` attribute is read: when it is ``None`` the table is created
        if missing or altered if present; when it is set but differs according
        to ``_is_schemas_simular`` the table is altered.
        """
        if self.schema is None:
            return
        actual_schema = self.get_attribute("schema")
        if actual_schema is None:
            if not self.exists():
                self.create_table()
            else:
                self.alter_table(schema=self.schema)
        elif not _is_schemas_simular(self.schema, actual_schema):
            self.alter_table(schema=self.schema)

    def merge_chunks(self, size=2 ** 30, mode='sorted'):
        """Merge small chunks into larger ones.

        Runs ``yt_client.run_merge`` from the table onto itself inside a
        transaction, with ``combine_chunks`` enabled.

        :param int size: value passed as ``data_size_per_job`` in the merge spec
        :param str mode: merge mode passed to ``run_merge``
        """
        with self.yt_client.Transaction(attributes={"title": "xprod-schema merge transaction"}):
            self.yt_client.run_merge(
                self.path,
                self.path,
                mode=mode,
                spec={'combine_chunks': True, 'data_size_per_job': size}
            )


def _is_schemas_simular(schema1, schema2):
    if len(schema1) != len(schema2):
        return False
    fields1 = frozenset((s['name'], s['type']) for s in schema1)
    fields2 = frozenset((s['name'], s['type']) for s in schema2)
    return fields1 == fields2
