import logging

from copy import deepcopy
import datetime
from six import string_types, next
from six.moves import zip_longest

try:
    import yt.wrapper as yt_wrapper
    import yt.wrapper.format as yt_format
    from yt.wrapper.errors import YtTabletNotMounted

    if yt_format.yson.TYPE == 'BINARY':
        yt_format = yt_wrapper.YsonFormat()
    else:
        logging.warning('Please add yt/python/yt_yson_bindings to ya.make for YSON feature')
        yt_format = yt_wrapper.JsonFormat()

except ImportError:
    yt_wrapper = None
    yt_format = None
    logging.warning('If you want to use YT, please add yt/python/client to ya.make')

from .base import KeyValueStore
from . import error

# Row limits passed to select_rows() as input_row_limit / output_row_limit
# by DynTableKeyValueStore.select().
INPUT_ROW_LIMIT = 100000000000
OUTPUT_ROW_LIMIT = 10000000

log = logging.getLogger(__name__)


class DynTableKeyValueStore(KeyValueStore):
    """Key-value store backed by a YT dynamic table.

    Each row is addressed by the namespace columns (``self.namespaces``) and
    the key columns (``self.keys_cols``); the payload lives in
    ``self.data_cols``. A namespace may be shorter than ``self.namespaces``:
    the unused trailing namespace columns are stored as ``None``.
    """

    def __init__(self, namespaces, keys_cols, data_cols):
        super(DynTableKeyValueStore, self).__init__(namespaces, keys_cols, data_cols)
        # NOTE(review): not read anywhere in this class.
        self._data = {}
        self._table_path = None  # set by init_client()
        self._proxy = None  # proxy given to / read from the YtClient
        self._client = None  # yt_wrapper.YtClient, set by init_client()

        # Non-default part of a caller-supplied YtClient config (token removed).
        # Serialized by dict() and used by from_dict() to rebuild the client.
        self._config_diff = None

    def table_path(self, client):
        """Return the path of the backing table.

        ``client`` is ignored here; presumably it lets subclasses choose a
        path per cluster.
        """
        return self._table_path

    # --- Thin wrappers around YtClient calls on this store's table ----------

    def _read_table(self, *args, **kwargs):
        return self._client.read_table(self.table_path(self._client), *args, **kwargs)

    def _select_rows(self, *args, **kwargs):
        # The query text is passed through unchanged, so it has to name the
        # table itself.
        # TODO: build the query from table_path() so it works per cluster.
        return self._client.select_rows(*args, **kwargs)

    def _lookup_rows(self, *args, **kwargs):
        return self._client.lookup_rows(self.table_path(self._client), *args, **kwargs)

    def _insert_rows(self, data):
        # Rows are always inserted with aggregate=True.
        return self._client.insert_rows(self.table_path(self._client), data, aggregate=True)

    def _delete_rows(self, *args, **kwargs):
        return self._client.delete_rows(self.table_path(self._client), *args, **kwargs)

    def _mount_table(self, *args, **kwargs):
        self._client.mount_table(self.table_path(self._client), *args, **kwargs)

    def _unmount_table(self, *args, **kwargs):
        self._client.unmount_table(self.table_path(self._client), *args, **kwargs)

    def _create(self, schema):
        self._client.create('table', self.table_path(self._client), attributes={'schema': schema, 'dynamic': True})

    def _exists(self):
        return self._client.exists(self.table_path(self._client))

    def _remove(self):
        self._client.remove(self.table_path(self._client))

    def _lookup_key(self, namespace, key):
        """Build the column->value dict addressing exactly (namespace, key).

        Namespace columns beyond ``len(namespace)`` are set to None; key
        columns are filled pairwise from ``key``.
        """
        row_key = dict(zip_longest(self.namespaces, namespace))
        row_key.update(zip(self.keys_cols, key))
        return row_key

    def transaction(self):
        """Return a tablet transaction from the YT client.

        Raises RuntimeError if init_client() has not set a client yet.
        """
        if self._client is None:
            log.error('YT client is None')
            raise RuntimeError('YT client is None')
        log.info('Start transaction')
        return self._client.Transaction(type='tablet')

    def select(self, query, allow_join_without_index=False, remount=True):
        """Run a YT QL ``query`` through the client's ``select_rows``.

        The query runs with full scans allowed, incomplete results tolerated
        and the module-level row limits. If the table is reported unmounted
        and ``remount`` is true, the table is mounted and the query is retried
        once with the *same* options; otherwise the error is re-raised.
        """
        log.debug('Running query %r', query)
        # Shared by the first attempt and the retry so both behave the same.
        select_kwargs = dict(
            format=yt_format,
            fail_on_incomplete_result=False,
            allow_join_without_index=allow_join_without_index,
            allow_full_scan=True,
            input_row_limit=INPUT_ROW_LIMIT,
            output_row_limit=OUTPUT_ROW_LIMIT,
        )
        try:
            return self._select_rows(query, **select_kwargs)
        except YtTabletNotMounted:
            log.warning('Table for sql request %s was unmounted', query)
            if not remount:
                raise
            self.mount()
            return self._select_rows(query, **select_kwargs)

    def upsert(self, namespace, key, data):
        """Insert or overwrite a single row; see bulk_upsert()."""
        return self.bulk_upsert([(namespace, key, data)])

    def bulk_upsert(self, objects):
        """Insert rows for an iterable of ``(namespace, key, data)`` triples.

        Each row is ``data`` plus the namespace/key columns, with every data
        column missing from ``data`` set to None. The caller's ``data`` dicts
        are not modified. Returns the result of the client's ``insert_rows``.
        """
        bulk_data = []
        for namespace, key, data in objects:
            self.validate_namespace(namespace)
            self.validate_key(key)
            # Work on a copy so the caller's dict is not mutated.
            row = dict(data)
            row.update(self._lookup_key(namespace, key))
            for column in self.data_cols:
                row.setdefault(column, None)
            bulk_data.append(row)

        return self._insert_rows(bulk_data)

    def read(self, namespace, key, column_list=None):
        """Return the merged row for ``key`` along the ``namespace`` prefixes.

        Rows are looked up for ``namespace[:1]``, ``namespace[:2]``, ...,
        ``namespace``. Deeper namespace columns are None in each lookup. The
        non-null values of the found rows are overlaid from the shortest
        prefix to the full one, so more specific namespaces win.

        :param column_list: optional iterable of columns to fetch in addition
            to the namespace and key columns.
        :returns: dict with namespace/key columns and merged values.
        :raises error.NoSuchNamespaceException: if no data column ends up
            non-null.
        """
        super(DynTableKeyValueStore, self).read(namespace, key, column_list)

        total_keys = self.namespaces + self.keys_cols
        key_values = dict(zip(self.keys_cols, key))

        request = {name: None for name in self.namespaces}
        request.update(key_values)

        requests = []
        frozen_requests = []
        for depth, value in enumerate(namespace):
            # Each prefix gets its own dict; earlier ones are already queued.
            request = dict(request)
            request[self.namespaces[depth]] = value
            requests.append(request)
            frozen_requests.append(frozenset(request.items()))

        column_names = None
        if column_list is not None:
            column_names = total_keys + tuple(column_list)

        # Index returned rows by their full (namespace + key) identity so they
        # can be matched back to the prefix requests above.
        answer = {}
        for row in self._lookup_rows(requests, column_names=column_names, format=yt_format):
            answer[frozenset((name, row[name]) for name in total_keys)] = row

        result = {name: None for name in self.namespaces}
        result.update(key_values)

        for frozen in frozen_requests:
            row = answer.get(frozen)
            if row is not None:
                result.update((column, value) for column, value in row.items() if value is not None)

        if any(result.get(column) is not None for column in self.data_cols):
            return result

        raise error.NoSuchNamespaceException(namespace, key)

    def filter(self, request=None):
        """Yield rows matching ``request``.

        If ``request`` is None, the table is remounted (unmount + mount) and
        every row is yielded via ``read_table``.

        Otherwise the usage below expects ``request.request`` to map each key
        to an iterable of namespace tuples, and ``request.column_list`` to be
        a set of column names or falsy. Each (key, namespace) pair is looked
        up. Pairs that are not found are retried with the namespace shortened
        by its last element until one element is left. Each found
        (key, namespace) is yielded at most once. The yielded row is limited
        to ``request.column_list`` when given; otherwise it holds data,
        namespace and key columns.
        """
        if request is None:
            log.info('request is None, getting all data after remount table')
            self.unmount()
            self.mount()
            log.debug('Remount finished. Now all data must be available in read_table')
            for row in self._read_table():
                yield row
            return

        if request.column_list:
            column_list = request.column_list | set(self.namespaces) | set(self.keys_cols)
        else:
            column_list = set(self.data_cols + self.namespaces + self.keys_cols)
        output_columns = request.column_list or column_list

        log.debug('Columns to read from YT: %s', column_list)

        def _ns(obj):
            # Namespace prefix actually stored in the row (None columns dropped).
            return tuple(obj[n] for n in self.namespaces if obj[n] is not None)

        def _key(obj):
            return tuple(obj[n] for n in self.keys_cols)

        requests = {}
        log.info('Preparing YT request')
        for key in request.request:
            log.debug('Processing key %s', key)
            yt_req = {x: y for x, y in zip(self.keys_cols, key)}

            for namespaces in request.request[key]:
                requests[(key, namespaces)] = dict(yt_req, **dict(zip_longest(self.namespaces, namespaces)))

        done = set()
        while requests:
            log.info('Making YT request with size: %s', len(requests))
            # Materialize the lookup keys: ``requests`` is popped from while
            # the response below is consumed, so a live view must not be used.
            lookup = list(requests.values())
            for row in self._lookup_rows(lookup, column_names=list(column_list), format=yt_format):
                found = (_key(row), _ns(row))
                log.debug('Get data for key %s and namespace %s', found[0], found[1])
                done.add(found)
                requests.pop(found)
                yield {column: row[column] for column in row if column in output_columns}

            log.info('Preparing YT request')
            new_requests = {}
            for key, ns in requests:
                depth = len(ns)
                if depth > 1:
                    # Fall back to the parent namespace for the next round.
                    req = requests[(key, ns)]
                    req[self.namespaces[depth - 1]] = None
                    parent = ns[0:depth - 1]
                    if (key, parent) not in done:
                        new_requests[(key, parent)] = req

            requests = new_requests

    def update(self, namespace, key, data):
        """Merge ``data`` into the existing row at exactly (namespace, key).

        Raises error.NoSuchNamespaceException if the row does not exist.
        Returns the result of the client's ``insert_rows``.
        """
        super(DynTableKeyValueStore, self).update(namespace, key, data)

        request = self._lookup_key(namespace, key)

        rows = self._lookup_rows([request], format=yt_format)
        try:
            origin = next(rows)
        except StopIteration:
            raise error.NoSuchNamespaceException(namespace, key)

        origin.update(data)
        # NOTE(review): _insert_rows uses aggregate=True; for aggregate columns
        # re-inserting the read value would add it again. build_schema() never
        # declares aggregate columns, but externally created tables might.
        return self._insert_rows([origin])

    def delete(self, namespace, key):
        """Delete the row at exactly (namespace, key) and return it.

        Raises error.NoSuchNamespaceException if the row does not exist.
        """
        super(DynTableKeyValueStore, self).delete(namespace, key)
        request = self._lookup_key(namespace, key)
        for row in self._lookup_rows([request], format=yt_format):
            self._delete_rows([request], format=yt_format)
            return row

        raise error.NoSuchNamespaceException(namespace, key)

    def init_client(self, table_path_proxy, token=None):
        """Set the table path and the YT client.

        ``table_path_proxy`` is a ``(table_path, proxy)`` pair. ``proxy`` is
        either a ready ``yt_wrapper.YtClient``, which is reused as-is and whose
        non-default config minus the token is remembered for dict(), or a
        proxy address. In the latter case a new RPC-backend client is created
        with ``token`` and dynamic-table retries disabled.

        Raises error.NoYtModuleException if the YT client library is missing.
        """
        if yt_wrapper is None:
            raise error.NoYtModuleException()

        table_path, proxy = table_path_proxy

        self._table_path = table_path

        if isinstance(proxy, yt_wrapper.YtClient):
            self._client = proxy
            self._proxy = self._client.config['proxy']['url']
            self._config_diff = self.find_config_diff(self._client.config)
            # Never persist credentials through dict().
            self._config_diff.pop('token', None)
        else:
            self._proxy = proxy
            # No custom client config here; drop any diff left from a
            # previous init so dict()/from_dict() do not resurrect it.
            self._config_diff = None
            self._client = yt_wrapper.YtClient(self._proxy,
                                               token,
                                               config={
                                                   'backend': 'rpc',
                                                   'dynamic_table_retries': {'enable': False}})

    # Python type -> YT schema type name used by build_schema(). Types not
    # listed here map to 'any'. Date/time types are declared as uint64
    # (presumably stored as numeric timestamps; verify against writers).
    TYPES = {
        int: 'int64',
        float: 'double',
        bool: 'boolean',
        str: 'string',
        datetime.datetime: 'uint64',
        datetime.date: 'uint64',
        datetime.time: 'uint64',
    }

    def mount(self):
        """Mount the table and wait for completion (sync=True)."""
        self._mount_table(sync=True)

    def unmount(self):
        """Unmount the table and wait for completion (sync=True)."""
        self._unmount_table(sync=True)

    def create_table(self, schema):
        """Create the dynamic table with ``schema``."""
        self._create(schema)

    def remove_table(self):
        """Remove the table node."""
        self._remove()

    def exists_table(self):
        """Return whether the table node exists."""
        return self._exists()

    def _to_schema_type(self, t):
        """Map ``t`` to a YT type name: strings pass through, known Python
        types go through TYPES, anything else becomes 'any'."""
        if isinstance(t, string_types):
            return t
        elif t in self.TYPES:
            return self.TYPES[t]
        else:
            return 'any'

    def build_schema(self, namespace_types, key_types, data_types):
        """Return a YT table schema for this store.

        Column order is: key columns, then namespace columns (both sorted
        ascending), then unsorted data columns. Each ``*_types`` sequence must
        have the same length as the matching column tuple, otherwise
        ValueError is raised. Types are resolved by _to_schema_type().
        """
        # TODO: Respect user-defined ordering, i.e. merge namespace and key_types to allow fine-tuning of
        # indices on the table: MULTIK-90
        schema = []

        if len(self.namespaces) != len(namespace_types):
            raise ValueError('Wrong size of namespace types')
        if len(self.keys_cols) != len(key_types):
            raise ValueError('Wrong size of key types')
        if len(self.data_cols) != len(data_types):
            raise ValueError('Wrong size of data types')

        for n, t in zip(self.keys_cols, key_types):
            schema.append({'name': n, 'type': self._to_schema_type(t), 'sort_order': 'ascending'})

        for n, t in zip(self.namespaces, namespace_types):
            schema.append({'name': n, 'type': self._to_schema_type(t), 'sort_order': 'ascending'})

        for n, t in zip(self.data_cols, data_types):
            schema.append({'name': n, 'type': self._to_schema_type(t)})

        return schema

    def init_store(self, table_path_proxy, token=None):
        """Initialise the client; see init_client()."""
        self.init_client(table_path_proxy, token)

    def connect_store(self):
        """Make the table usable by mounting it."""
        self.mount()

    def create_store(self, namespace_types, key_types, data_types):
        """Create the backing table from the given column types."""
        schema = self.build_schema(namespace_types, key_types, data_types)
        self.create_table(schema)

    def drop_store(self):
        """Unmount and remove the backing table."""
        self.unmount()
        self.remove_table()

    @classmethod
    def from_dict(cls, data, path=None):
        """Rebuild a store from the output of dict().

        If a client config diff was saved, a YtClient is rebuilt from it and
        the store is connected (mounted). The token was stripped from the
        diff, so that client must find credentials on its own. Otherwise, if a
        table path was saved, the store is connected via the saved proxy.
        ``path`` is unused here.

        Raises error.NoYtModuleException if a client must be built but the YT
        client library is missing.
        """
        kv = cls(tuple(data['namespaces']), tuple(data['keys_cols']), tuple(data['data_cols']))
        config_diff = data.get('config_diff')
        if config_diff is not None:
            if yt_wrapper is None:
                raise error.NoYtModuleException()
            client = yt_wrapper.YtClient(config=config_diff)
            kv.init_store((data['table_path'], client))
            kv.connect_store()
        elif data['table_path'] is not None:
            kv.init_store((data['table_path'], data['proxy']))
            kv.connect_store()
        return kv

    @classmethod
    def find_config_diff(cls, config, default=None):
        """Return the nested part of ``config`` that differs from ``default``.

        ``default`` falls back to the YT wrapper's default config. Nested
        ``VerifiedDict`` sections are compared recursively, and sections
        without differences are omitted.
        """
        default = default if default is not None else yt_wrapper.default_config.get_default_config()
        keys = config.keys()
        ans = {}
        for k in keys:
            if isinstance(config[k], yt_wrapper.mappings.VerifiedDict):
                tmp = cls.find_config_diff(config[k], default[k])
                if tmp:
                    ans[k] = tmp
            elif config[k] != default[k]:
                ans[k] = config[k]
        return ans

    def dict(self, *args):
        """Serialize via the base class, adding table path, proxy and config diff."""
        return super(DynTableKeyValueStore, self).dict({
            'table_path': self._table_path,
            'proxy': self._proxy,
            'config_diff': self._config_diff
        }, *args)

    def files(self):
        """Return local files backing this store (none: data lives in YT)."""
        return []
