# -*- coding: utf-8 -*-

import json
import logging
import os
import threading
import time
from collections import defaultdict

from .utils import get_job_logger
from .config import (
    LOGTYPES,
    LOCKER_TRANSACTION_TIMEOUT, WAIT_LOCK_TIMEOUT,
    YT_REQUEST_RETRIES, YT_REQUEST_TIMEOUT, YT_HEAVY_REQUEST_TIMEOUT,
)

# Keys of the per-job control document; JobYtClient.read_control_state
# initialises it as {MANAGER_STATE: True, READER_STATE: False}.
# NOTE(review): presumably these flags switch the manager/reader on and off.
MANAGER_STATE = 'run_manager'
READER_STATE = 'run_reader'
# Path component inserted after the path prefix for every node JobYtClient
# addresses (tasks table, control, lock, errors, cache, input).
NAMESPACE = 'yabs-lbyt-reader'
# Schema used when the tasks table is created; 'num' is the only sorted
# (key) column, the remaining columns are plain values.
TASKS_TABLE_SCHEMA = [
    {'name': 'num', 'type': 'int64', 'sort_order': 'ascending'},
    {'name': 'time_mark', 'type': 'int64'},
    {'name': 'state', 'type': 'string'},
    {'name': 'data', 'type': 'any'},
    {'name': 'meta', 'type': 'any'},
]

# Upper bound on the number of keys sent in one lookup_rows call by
# DataProxy._do_load.
MAX_DATA_PER_REQUEST = 100


class DataProxy(object):
    """Lazy holder for the ``data`` column of a single task row.

    A proxy is first registered with ``request()`` and later filled, together
    with all other registered proxies, by ``load_requested()``. The pending
    requests and the lock are class attributes, so they are shared by every
    proxy instance.
    """

    # job_client -> set of proxies waiting for their data to be read
    requests = defaultdict(set)
    # Protects ``requests`` and the ``loaded`` checks in request()/get()
    lock = threading.Lock()

    @classmethod
    def load_requested(cls):
        """Read data for every proxy registered through ``request()``.

        Returns True when all queued requests were processed, False as soon
        as reading for one client fails. The requests of the failed client
        have already been popped from the queue and are not re-queued;
        requests of clients not processed yet stay queued.
        """
        with cls.lock:
            while cls.requests:
                job_client, requested = cls.requests.popitem()
                if not cls._do_load(job_client, list(requested)):
                    return False
        return True

    @staticmethod
    def _do_load(job_client, objects):
        """Fill ``objects`` (proxies of one client) in batches.

        At most MAX_DATA_PER_REQUEST keys are looked up per request, and a
        connection is obtained from ``job_client.get_yt_connection()`` for
        every batch. Only proxies for which a row is returned are marked as
        loaded. Returns False (after logging) on the first failed batch.
        """
        logger = get_job_logger()
        logger.info('Reading data for %d tasks from cluster %s', len(objects), job_client.name)
        # ``range`` instead of the Python 2-only ``xrange``: works on both
        # interpreters, and the sequence of batch starts is tiny anyway.
        for start in range(0, len(objects), MAX_DATA_PER_REQUEST):
            data_proxies = {
                data_proxy.num: data_proxy
                for data_proxy in objects[start:start + MAX_DATA_PER_REQUEST]
            }
            try:
                for row in job_client.get_yt_connection().lookup_rows(
                    job_client.tasks_table,
                    [{'num': num} for num in data_proxies],
                ):
                    data_proxy = data_proxies[row['num']]
                    # Update in place: callers may already hold the dict
                    # returned by request().
                    data_proxy.data.update(row['data'])
                    data_proxy.loaded = True
            except Exception:
                logger.exception(
                    'Failed to read data from cluster %s',
                    job_client.name,
                )
                return False
        return True

    def __init__(self, job_client, num):
        """Create an unloaded proxy for task ``num`` served by ``job_client``."""
        self.job_client = job_client
        self.loaded = False
        self.num = num
        self.data = {}

    def request(self):
        """Queue this proxy for loading (if not loaded yet) and return its data dict.

        The returned dict is the same object that ``load_requested()`` later
        fills, so it is empty until the data has been read.
        """
        with self.lock:
            if not self.loaded:
                self.requests[self.job_client].add(self)
        return self.data

    def get(self):
        """Return the loaded data; raise RuntimeError if it was not loaded yet."""
        with self.lock:
            if not self.loaded:
                raise RuntimeError('Trying to get not loaded data')
        return self.data


class TaskInfo(object):
    """One task row exposed as an object, plus a lazy proxy for its data."""

    def __init__(self, job_client, row):
        # Every column of the row becomes an instance attribute.
        attributes = vars(self)
        attributes.update(row)
        # The row is expected to carry 'num'; it identifies the data to load.
        self.data_proxy = DataProxy(job_client, self.num)


class JobYtClient(object):
    """Access to one YT cluster for one job: tasks table, control document, tables.

    Task rows to be written are accumulated in memory (``add_task`` /
    ``mark_tasks``) and sent with ``flush_tasks``. Most operations log
    failures and report them through their return value (or by leaving the
    cached attribute empty) instead of raising.
    """

    @staticmethod
    def configure_yt_client(proxy, token):
        """Build a new ``YtClient`` for ``proxy`` with the module's retry/timeout settings."""
        # Imported lazily so the module can be imported without yt.wrapper.
        from yt.wrapper import YtClient

        yt_logger = logging.getLogger('Yt')
        yt_logger.setLevel(logging.DEBUG)
        yt_logger.propagate = True

        retries_config = {
            'count': YT_REQUEST_RETRIES,
            'backoff': {
                'policy': 'rounded_up_to_request_timeout',
            },
        }

        return YtClient(
            proxy=proxy,
            token=token,
            config={
                'dynamic_table_retries': retries_config,
                'proxy': {
                    'retries': retries_config,
                    # NOTE(review): the constants are used as seconds elsewhere
                    # in this class, so "* 1000" presumably converts to ms.
                    'request_timeout': YT_REQUEST_TIMEOUT * 1000,
                    'heavy_request_timeout': YT_HEAVY_REQUEST_TIMEOUT * 1000,
                },
            },
        )

    def __init__(self, job, cluster, path_prefix, get_token_func):
        """Compute all paths for ``job`` on ``cluster``; no request is made here.

        :param job: job name, used as the last path component.
        :param cluster: cluster name.
        :param path_prefix: absolute YT path (must start with "//").
        :param get_token_func: callable returning the token; not called in
            local mode.
        :raises ValueError: if ``path_prefix`` is not an absolute path.
        """
        if len(path_prefix) < 3 or path_prefix[:2] != '//':
            raise ValueError(
                'Invalid yt path, must be absolute (i.e. starts with "//"): ' + path_prefix
            )
        # Local mode is switched on by the presence of /etc/testing.
        self.localmode = os.path.isfile('/etc/testing')
        self.name = cluster
        if self.localmode:
            # A single local proxy serves all clusters, so the cluster name
            # is put into the path to keep them apart.
            self.token = None
            self.proxy = 'localhost:9020'
            self.path_prefix = '//' + self.name + path_prefix[1:]
        else:
            self.token = get_token_func()
            self.proxy = cluster + '.yt.yandex.net'
            self.path_prefix = path_prefix
        self.tasks_table = '/'.join([self.path_prefix, NAMESPACE, job])
        self.job_control = '/'.join([self.path_prefix, NAMESPACE, 'control', job])
        self.lock_prefix = '/'.join([self.path_prefix, NAMESPACE, 'lock', job])
        self.errors = '/'.join([self.path_prefix, NAMESPACE, 'errors'])
        self.cache = '/'.join([self.path_prefix, NAMESPACE, 'cache'])
        self.input = '/'.join([self.path_prefix, NAMESPACE, 'input'])
        self.control_state = None
        self.tasks = None
        self._to_write = []
        self._yt = None

    @property
    def prepared_tasks(self):
        """Number of task rows waiting to be written by ``flush_tasks``."""
        return len(self._to_write)

    @property
    def yt(self):
        """Shared client for this cluster, created on first access.

        :raises RuntimeError: if the client could not be created (the
            original error is logged).
        """
        if self._yt is None:
            logger = get_job_logger()
            logger.info('Connecting to yt cluster %s via %s', self.name, self.proxy)
            try:
                self._yt = self.get_yt_connection()
            except Exception:
                logger.exception('Failed to create yt connection for cluster %s', self.name)
                raise RuntimeError('No connection')
        return self._yt

    def get_yt_connection(self):
        """Return a new, independent client for this cluster."""
        return self.configure_yt_client(self.proxy, self.token)

    def read_tasks(self):
        """Reload ``self.tasks`` from the tasks table, creating the table if needed.

        Pending writes are discarded. On failure the error is logged and
        ``self.tasks`` stays None.
        """
        logger = get_job_logger()
        logger.info('Reading tasks from yt cluster %s', self.name)
        self.tasks = None  # invalidate cache
        self._to_write = []  # state changed, avoid writing old data
        try:
            self._ensure_tasks_table()
            # 'data' and 'meta' are not selected here; data is read
            # separately through TaskInfo.data_proxy.
            query = 'num,time_mark,state from [{}]'.format(self.tasks_table)
            self.tasks = [TaskInfo(self, row) for row in self.yt.select_rows(query)]
            logger.info('Successfully read %d tasks from yt cluster %s', len(self.tasks), self.name)
        except Exception:
            logger.exception('Failed to read tasks from yt cluster %s', self.name)

    def read_control_state(self):
        """Reload ``self.control_state`` from the control document.

        The document is created with the default state if it does not exist.
        On failure the error is logged and ``self.control_state`` stays ``{}``.
        """
        logger = get_job_logger()
        logger.info('Reading control state from yt cluster %s', self.name)
        self.control_state = {}  # invalidate cache
        try:
            if not self.yt.exists(self.job_control):
                with self.yt.Transaction():
                    self.yt.create('document', self.job_control, recursive=True)
                    # Only manager is active by default
                    new_state = json.dumps({MANAGER_STATE: True, READER_STATE: False})
                    self.yt.set(self.job_control, new_state, format='json')
            self.control_state = self.yt.get(self.job_control)
        except Exception:
            logger.exception('Failed to read job control state from yt cluster %s', self.name)
        else:
            # Report success only when the read really succeeded.
            logger.info('Successfully read control state from yt cluster %s', self.name)

    def add_task(self, new_task):
        """Queue ``new_task`` (a row dict) for writing.

        If a row with the same 'num' is already queued, it is updated with
        the new values instead of being queued twice.

        :raises ValueError: if ``new_task`` has no 'num' key.
        """
        if 'num' not in new_task:
            raise ValueError(
                'Column "num" is a key for tasks, but it is not specified in {}'.format(new_task)
            )
        for task in self._to_write:
            if task['num'] == new_task['num']:
                task.update(new_task)
                return
        self._to_write.append(new_task)

    def mark_tasks(self, to_mark, state):
        """Queue a 'state' change to ``state`` for every task number in ``to_mark``."""
        for task_num in to_mark:
            found = False
            for task in self._to_write:
                if task['num'] == task_num:
                    task['state'] = state
                    found = True
                    break
            if not found:
                self._to_write.append({'num': task_num, 'state': state})

    def flush_tasks(self):
        """Write all queued rows; return True on success (or nothing to write).

        On failure the rows stay queued, the error is logged and False is
        returned.
        """
        if not self._to_write:
            return True
        try:
            self.yt.insert_rows(self.tasks_table, self._to_write, update=True)
            # Ensure we will not try to write same tasks again.
            self._to_write = []
        except Exception:
            get_job_logger().exception('Failed to write tasks to yt cluster %s', self.name)
            return False
        return True

    def delete_tasks(self, to_delete):
        """Delete rows with the given task numbers; return True on success.

        Queued writes for deleted tasks are dropped as well. On failure the
        error is logged and False is returned.
        """
        if not to_delete:
            return True
        try:
            self.yt.delete_rows(self.tasks_table, ({'num': task_num} for task_num in to_delete))
            # Ensure we will not try to rewrite deleted tasks.
            self._to_write = [
                task for task in self._to_write
                if task['num'] not in to_delete
            ]
        except Exception:
            get_job_logger().exception('Failed to delete tasks from yt cluster %s', self.name)
            return False
        return True

    def move_tables(self, move_pairs):
        """Move every ``(src, dst)`` pair inside one transaction; return True on success."""
        if not move_pairs:
            return True
        try:
            with self.yt.Transaction():
                for src, dst in move_pairs:
                    self.yt.move(src, dst, recursive=True)
        except Exception:
            get_job_logger().exception('Failed to move tables on yt cluster %s', self.name)
            return False
        return True

    def remove_tables(self, to_remove):
        """Remove every listed table inside one transaction; return True on success."""
        if not to_remove:
            return True
        try:
            with self.yt.Transaction():
                for table in to_remove:
                    self.yt.remove(table)
        except Exception:
            get_job_logger().exception('Failed to remove tables on yt cluster %s', self.name)
            return False
        return True

    def _ensure_tasks_table(self):
        """Create and mount the dynamic tasks table if it does not exist.

        Waits until the table reports the 'mounted' tablet state, polling
        once per second for at most YT_HEAVY_REQUEST_TIMEOUT seconds.

        :raises RuntimeError: if the table is not mounted in time. Any error
            during mounting removes the freshly created table and is
            re-raised.
        """
        if self.yt.exists(self.tasks_table):
            return
        self.yt.create_table(self.tasks_table, recursive=True, attributes={
            'schema': TASKS_TABLE_SCHEMA,
            'dynamic': True,
        })
        try:
            self.yt.mount_table(self.tasks_table)
            end_time = time.time() + YT_HEAVY_REQUEST_TIMEOUT
            while self.yt.get(self.tasks_table + '/@tablet_state') != 'mounted':
                if time.time() > end_time:
                    raise RuntimeError('Timed out mount_table')
                time.sleep(1)
        except:
            # Catch-all on purpose: whatever interrupted mounting, do not
            # leave a half-initialised table behind, then re-raise.
            self.yt.remove(self.tasks_table)
            raise


class YtQuorumLock(object):
    """Context manager that takes an exclusive lock on several YT clusters at once.

    On every cluster a lock node named after ``lock_name`` is locked inside
    a dedicated transaction. ``check_quorum_state()`` refreshes a snapshot
    of who holds locks on each cluster; the ``have_*`` methods evaluate that
    snapshot against the quorum (strict majority of clusters).
    """

    def __init__(self, yt_clusters, lock_name):
        """Remember the clusters and read the initial lock state from them."""
        self.yt_clusters = yt_clusters
        self.lock_name = lock_name
        # cluster name -> transaction holding (or waiting for) our lock
        self.transactions = {}
        self.check_quorum_state()

    def __enter__(self):
        """Try to lock every cluster in parallel and wait for all attempts.

        Failures on individual clusters are only logged by the worker
        threads; use ``check_quorum_state()`` and ``have_quorum()`` to find
        out whether enough locks were obtained.

        :raises RuntimeError: if transactions from a previous use are still
            registered.
        """
        if self.transactions:
            raise RuntimeError('YtQuorumLock is in invalide state, probably there was exception in __exit__')
        try:
            # Serialises the writes to self.transactions from the workers.
            thread_lock = threading.Lock()
            threads = [
                threading.Thread(target=self._get_lock_transaction, args=(yt_cluster, thread_lock))
                for yt_cluster in self.yt_clusters
            ]
            for thread in threads:
                thread.daemon = True
                thread.start()
            for thread in threads:
                thread.join()
        except:
            # Catch-all on purpose: release whatever was locked so far and
            # re-raise.
            self._finish_transactions()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Abort all lock transactions, releasing the locks."""
        self._finish_transactions()

    def check_quorum_state(self):
        """Refresh ``self.quorum_state`` from the clusters.

        For every reachable cluster the snapshot holds
        ``{'under_my_lock': bool, 'locks_count': int}``, where only exclusive
        locks are counted. Clusters whose lock node cannot be read are left
        out. If we have a transaction on a cluster but its lock is not listed
        any more, one non-waiting attempt to lock again is made.
        """
        quorum_state = {}
        logger = get_job_logger()
        for yt_cluster in self.yt_clusters:
            lock_path = '{}:{}'.format(yt_cluster.lock_prefix, self.lock_name)
            try:
                locks = yt_cluster.yt.get_attribute(lock_path, 'locks')
            except Exception as err:
                logger.warning('Check lock %s at %s failed: %s', lock_path, yt_cluster.name, err)
                continue
            my_tx = self.transactions.get(yt_cluster.name)
            cluster_state = {'under_my_lock': False, 'locks_count': 0}
            for lock in locks:
                if lock['mode'] != 'exclusive':
                    continue
                if my_tx is not None and lock['transaction_id'] == my_tx.transaction_id:
                    cluster_state['under_my_lock'] = True
                cluster_state['locks_count'] += 1
            if my_tx is not None and not cluster_state['under_my_lock']:
                logger.warning(
                    'Lock %s on yt cluster %s has gone, relocking',
                    lock_path, yt_cluster.name,
                )
                if self._get_lock_transaction(yt_cluster, waitable=False):
                    cluster_state['under_my_lock'] = True
                    cluster_state['locks_count'] += 1
            quorum_state[yt_cluster.name] = cluster_state
        logger.debug('Quorum lock state: %s', quorum_state)
        self.quorum_state = quorum_state

    def have_queue_on_quorum(self):
        """True if a quorum of clusters has more than one exclusive lock listed."""
        have_queue = sum(
            1 for cluster_state in self.quorum_state.values()
            if cluster_state['locks_count'] > 1
        )
        return have_queue >= self._quorum

    def have_quorum(self):
        """True if our lock is listed on a quorum of clusters."""
        have_lock = sum(
            1 for cluster_state in self.quorum_state.values()
            if cluster_state['under_my_lock']
        )
        return have_lock >= self._quorum

    def have_waiting_task(self):
        """True if a quorum of clusters lists an exclusive lock other than ours."""
        have_waiting = sum(
            1 for cluster_state in self.quorum_state.values()
            if cluster_state['locks_count'] > (1 if cluster_state['under_my_lock'] else 0)
        )
        return have_waiting >= self._quorum

    @property
    def _quorum(self):
        """Smallest number of clusters that forms a strict majority."""
        # Floor division keeps this an int on Python 3 as well; with true
        # division three clusters would need 2.5 (i.e. all three) locks.
        return len(self.yt_clusters) // 2 + 1

    def _get_lock_transaction(self, yt_cluster, thread_lock=None, waitable=True):
        """Start a transaction on ``yt_cluster`` and lock the lock node in it.

        :param yt_cluster: cluster client providing ``get_yt_connection()``,
            ``lock_prefix`` and ``name``.
        :param thread_lock: optional lock held while the transaction is
            registered (used when called from worker threads).
        :param waitable: passed to the lock request together with
            ``wait_for``.
        :returns: True if the transaction was registered in
            ``self.transactions``, False if anything failed (logged).
        """
        transaction = None
        try:
            # We need separate connection for lock because we don't
            # want to abort other transactions in case of fail
            locker = yt_cluster.get_yt_connection()
            # Ensure path exists
            lock_path = '{}:{}'.format(yt_cluster.lock_prefix, self.lock_name)
            locker.create('uint64_node', lock_path, recursive=True, ignore_existing=True)
            # Start transaction
            transaction = locker.Transaction(
                timeout=(LOCKER_TRANSACTION_TIMEOUT * 1000),
                interrupt_on_failed=False,
            )
            # Try to lock
            with locker.Transaction(transaction_id=transaction.transaction_id):
                locker.lock(lock_path, waitable=waitable, wait_for=(WAIT_LOCK_TIMEOUT * 1000))
        except:
            # Kept as a catch-all: every failure is reported as "no lock".
            get_job_logger().exception('Failed to get lock on yt cluster %s', yt_cluster.name)
            # Cluster is not available, abort transaction if it was already started
            if transaction:
                transaction.abort()
            return False
        else:
            if thread_lock is not None:
                with thread_lock:
                    self.transactions[yt_cluster.name] = transaction
            else:
                self.transactions[yt_cluster.name] = transaction
            return True

    def _forget_cluster_transaction(self, cluster_name):
        """Abort and unregister the transaction of ``cluster_name``, if any.

        If aborting raises, the transaction stays registered.
        """
        if cluster_name not in self.transactions:
            return
        self.transactions[cluster_name].abort()
        del self.transactions[cluster_name]

    def _finish_transactions(self):
        """Abort and unregister the transactions of all clusters."""
        for yt_cluster in self.yt_clusters:
            self._forget_cluster_transaction(yt_cluster.name)


def get_task_offsets(task_data, for_next):
    """Map every logtype to ``{'<topic>:<partition>': offset}``.

    ``task_data`` maps a logtype to a list of rows with 'topic', 'partition'
    and 'offset' keys. When ``for_next`` is true, each row's 'limit' is added
    to its offset; otherwise 'limit' is not read at all.
    """
    offsets_by_logtype = {}
    for logtype in task_data:
        partition_offsets = {}
        for row in task_data[logtype]:
            position = row['offset']
            if for_next:
                position = position + row['limit']
            partition_key = '{}:{}'.format(row['topic'], row['partition'])
            partition_offsets[partition_key] = position
        offsets_by_logtype[logtype] = partition_offsets
    return offsets_by_logtype


def get_table_name(job, task_info, logtype):
    """Build the table name ``<job>.<num zero-padded to 8>.<log name>``.

    The log name is the 'lb_name' configured for ``logtype`` in LOGTYPES;
    when the logtype or its 'lb_name' is not configured, ``logtype`` itself
    is used.
    """
    logtype_settings = LOGTYPES.get(logtype, {})
    log_name = logtype_settings.get('lb_name', logtype)
    return '{}.{:08}.{}'.format(job, task_info.num, log_name)


def parse_table_name(table_name):
    """Split a dot-separated table name into its job, task number and logtype.

    The name must consist of three or four dot-separated parts; a fourth
    part, if present, is ignored.

    :returns: ``{'job': str, 'num': int, 'logtype': str}``.
    :raises AssertionError: if the name has an unexpected number of parts.
    :raises ValueError: if the second part is not an integer.
    """
    parts = table_name.split('.')
    # Raised explicitly rather than via ``assert`` so the check is not
    # stripped under ``python -O``; the exception type is kept for callers.
    if len(parts) not in (3, 4):
        raise AssertionError('Unexpected table name format')
    return {
        'job': parts[0],
        'num': int(parts[1]),
        'logtype': parts[2],
    }
