from abc import abstractmethod, ABCMeta
from six import add_metaclass
from multiprocessing.pool import ThreadPool
from multiprocessing import cpu_count, Value
from uuid import uuid4
import datetime
import inspect
import logging
import re
import socket
import sys
import time
import traceback

from psycopg2 import connect
import psycopg2.errors
from pq.tasks import PQ

import django.db as django_db
from irt.utils import suppress, get_deploy_node_fqdn, get_deploy_pod_fqdn, get_deploy_box_ipv6
from irt.init_subclass import suppress_unused
from django.contrib.auth.models import User
from django.conf import settings

from .models import Operation, OperationMessage, Queue


# Module-level logger shared by WorkerQueue and Worker.
logger = logging.getLogger(__name__)
# NOTE(review): suppress_unused() comes from irt.init_subclass and its effect is not visible
# here; presumably it keeps the irt.init_subclass import (which may provide the
# __init_subclass__ hook WorkerQueue relies on) from being flagged/removed as unused — confirm.
suppress_unused()


@add_metaclass(ABCMeta)
class WorkerQueue(object):
    """
    Base class for task queues backed by PQ (PostgreSQL) and tracked in ``Operation`` rows.

    Subclasses must set ``queue_name`` (name of the PQ queue) and ``worker_class``
    (the ``Worker`` subclass instantiated for every job) and implement ``validate``.
    Every subclass is registered in ``WorkerQueue._workers`` under its queue type
    (see ``_get_queue_type``) and can be looked up with ``get_queue``.
    """
    queue_name = None
    worker_class = None
    # Seconds without an 'updated_at' refresh after which find_dead_jobs treats an operation as dead.
    MAX_UPDATE_TIMEOUT = 300
    # Seconds _call waits for the Operation linked to a job to appear in the db.
    MAX_OPERATION_WAIT = 30
    # Seconds between find_dead_jobs runs in dead_jobs_loops.
    DEAD_LOOP_DELAY = 300
    # Seconds between 'updated_at' refreshes of a running job.
    STATUS_LOOP_DELAY = 60
    # Consecutive failed find_dead_jobs runs after which an error is logged.
    MAX_RECONNECT_COUNT = 10
    # PQ reconnects (counter is reset after a processed job) after which an error is logged.
    MAX_PQ_RECONNECT_COUNT = 10
    # Seconds loop() pauses after a failed poll, so an unreachable database is not hammered in a busy loop.
    PQ_RECONNECT_DELAY = 1

    # Registry of queue classes keyed by queue type; filled by __init_subclass__.
    _workers = dict()

    def __init__(self, max_retries=0, retry_in='10s', schedule_at=None, expected_at=None):
        """
        Open the database connections and register this queue's job handler with PQ.

        :param max_retries: How many times a failed job is rescheduled (stored in the job payload)
        :param retry_in: Delay passed as ``schedule_at`` when a failed job is rescheduled
        :param schedule_at: ``schedule_at`` passed to PQ for jobs created by ``push``
        :param expected_at: ``expected_at`` passed to PQ for jobs created by ``push``
        :raises AttributeError: if the subclass did not set ``queue_name`` or ``worker_class``
        """
        if self.queue_name is None:
            raise AttributeError("Can't create Queue class without property 'queue_name'")
        if self.worker_class is None:
            raise AttributeError("Can't create Queue class without property 'worker_class'")

        self._connect_kwargs = {
            'dbname': settings.DATABASES['default']['NAME'],
            'user': settings.DATABASES['default']['USER'],
            'host': settings.DATABASES['default']['HOST'],
            'port': settings.DATABASES['default']['PORT'],
            'password': settings.DATABASES['default']['PASSWORD'],
            'target_session_attrs': 'read-write'
        }

        # Raw connection used by _find_dead_jobs.
        self._conn = connect(**self._connect_kwargs)
        self._reconnect_count = 0

        # Separate connection owned by PQ.
        self._pq = PQ(connect(**self._connect_kwargs))
        self._queue = self._pq[self.queue_name]
        self._pq_reconnect_count = 0

        self._max_retries = max_retries
        self._retry_in = retry_in
        self._schedule_at = schedule_at
        self._expected_at = expected_at

        # Dotted name used both as the PQ handler key and as the 'function' field of pushed jobs.
        self.__name__ = '{}.{}'.format(self.__module__, type(self).__name__)

        self._queue.handler_registry[self.__name__] = self._call

        # job_id -> True while that job's status loop should keep running.
        self._loops = {}
        self._worker_threads = []

        self._terminating = Value('b', False)
        self._pool = None

    def kill(self):
        """Ask loop() and dead_jobs_loops() to stop at their next iteration."""
        self._terminating.value = True

    @classmethod
    def _get_queue_type(cls):
        """Lower-cased first CamelCase word of the class name, e.g. ``BackupQueue`` -> ``'backup'``."""
        return re.findall('[A-Z][^A-Z]*', cls.__name__)[0].lower()

    @classmethod
    def __init_subclass__(cls, **kwargs):
        # Register every subclass so it can be found by queue type via get_queue().
        WorkerQueue._workers[cls._get_queue_type()] = cls

    @staticmethod
    def get_queue(name):
        """Return the queue class registered for queue type ``name``, or None."""
        return WorkerQueue._workers.get(name)

    def _wait_for_operation(self, job_id):
        """
        Poll the db until the Operation linked to ``job_id`` exists.

        :param job_id: Id of job from PQ
        :return: The Operation, or None if it did not appear within MAX_OPERATION_WAIT seconds
        """
        deadline = time.time() + self.MAX_OPERATION_WAIT
        while time.time() < deadline:
            try:
                return Operation.objects.get(queue=job_id)
            except Operation.DoesNotExist:
                logger.error('Waiting for operation {}'.format(job_id))
                time.sleep(1)
        return None

    def _run_worker(self, op, job_id, args, kwargs):
        """
        Instantiate ``worker_class`` for ``op``, run it and store its result or error.

        :param op: Operation of the job
        :param job_id: Id of job from PQ (used for logging)
        :param args: Positional arguments for ``worker.run``
        :param kwargs: Keyword arguments for ``worker.run``
        :return: True if the job succeeded, False if it raised
        """
        worker = None
        try:
            worker = self.worker_class(op)
            worker.set_progress()
            res = worker.run(*args, **kwargs)
        except Exception as exception:
            logger.exception("Got exception, while runnig job for operation %s", job_id)
            tb = ''.join(traceback.format_tb(sys.exc_info()[2]))
            if worker is not None:
                worker.set_error(exception, tb)
            else:
                # worker_class(op) itself failed, so there is no worker to report through;
                # record the error on the operation directly, with no worker-specific last error.
                op.set_error(None, exception, tb)
            return False
        logger.info('Setting result: %s for operation %s', res, job_id)
        worker.set_result(res)
        return True

    def _call(self, job_id, *args, **kwargs):
        """
        PQ handler: run the job's worker and keep its Operation status up to date.

        Failed jobs are handed to ``_reschedule``.

        :param job_id: Id of job from PQ
        :param args: Arguments for function
        :param kwargs: Keyword arguments for function
        :return:
        """
        logger.info('Waiting for operation {} to become available in db'.format(job_id))
        # push() links the Operation to the job only after the job is queued, so it may lag behind.
        op = self._wait_for_operation(job_id)
        if op is None:
            logger.error('No operation for job: {}'.format(job_id))
            return

        logger.info('Starting job {}'.format(job_id))

        self._start_status_loop(job_id)
        try:
            passed = self._run_worker(op, job_id, args, kwargs)
        finally:
            # Always stop the heartbeat; a leftover loop would keep refreshing 'updated_at'
            # forever, and find_dead_jobs could then never pick the job up again.
            self._end_status_loop(job_id)

        if passed:
            logger.info('Job {} finished successfuly'.format(job_id))
        else:
            logger.warning('Job {} failed, rescheduling'.format(job_id))
            self._reschedule(job_id)

        self._pq_reconnect_count = 0

    def _reschedule(self, job_id):
        """
        Re-queue a failed job while its retry budget lasts.

        Same logic as in pq, but the retried job gets a new job id, so the Operation is
        re-pointed at the new queue row and marked STATUS_RESCHEDULED.

        :param job_id: Id of job from PQ
        :return:
        """
        logger.debug('Trying to reschedule job {}'.format(job_id))
        q = Queue.objects.get(id=job_id)
        retried = q.data['retried']
        if q.data.get('max_retries', 0) > retried:
            logger.info('Job {} will be recsheduled'.format(job_id))
            q.data['retried'] += 1
            new_id = self._queue.put(q.data, schedule_at=q.data['retry_in'])
            op = Operation.objects.get(queue=job_id)
            op.queue = Queue.objects.get(id=new_id)
            op.operation_status = Operation.STATUS_RESCHEDULED
            op.save()
            logger.info('Job {} rescheduled'.format(job_id))

    def create(self):
        """
        Create the PQ table; an already existing table is not an error.
        :return:
        """
        logger.debug('Trying to create pq table')
        with suppress(psycopg2.errors.DuplicateTable):
            self._pq.create()

    def _status_loop(self, job_id):
        """
        Heartbeat: re-save the job's Operation every STATUS_LOOP_DELAY seconds while the
        job is registered in ``self._loops`` (re-saving refreshes 'updated_at').

        :param job_id: Id of job from PQ
        :return:
        """
        while job_id in self._loops:
            try:
                logger.debug("Refreshing 'updated_at' for job {}".format(job_id))
                Operation.objects.get(queue=job_id).save()
            except Exception as e:
                logger.exception('Unknown error: %s', e)
            time.sleep(self.STATUS_LOOP_DELAY)

    def _start_status_loop(self, job_id):
        """
        Start status loop for job in a background thread.
        :param job_id: Id of job from PQ
        :return:
        """
        self._loops[job_id] = True
        pool = ThreadPool(1)
        pool.apply_async(self._status_loop, (job_id, ))
        # close() lets the pool shut its threads down once the loop returns, instead of
        # leaving an idle one-thread pool behind for every job.
        pool.close()

    def _end_status_loop(self, job_id):
        """
        Finish status loop for job (it exits after its current sleep).
        :param job_id: Id of job from PQ
        :return:
        """
        self._loops.pop(job_id, None)

    def check_state(self):
        """
        Check the worker threads started by async_loop and restart every finished one.
        :return: True if all loops are still running, False if any had to be restarted or the check failed
        """
        try:
            all_running = True
            for ind, th_result in enumerate(self._worker_threads):
                if not th_result.ready():
                    continue
                try:
                    result = th_result.get()
                except Exception as e:
                    logger.exception('Thread finished with exception: %s', e)
                    result = None

                logger.error('Thread %s is ready. Result: %s', th_result, result)

                self._worker_threads[ind] = self._pool.apply_async(self.loop)
                all_running = False
            return all_running
        except Exception as e:
            logger.exception('Unknown error in check_state: %s', e)
            return False

    @abstractmethod
    def validate(self, *args, **kwargs):
        pass

    def async_loop(self, worker_count=cpu_count()):
        """
        Start worker threads in background
        :param worker_count: Count of threads for worker (default: CPU count at import time)
        :return: List of thread results
        """
        # This pool stays open: check_state() submits replacement loops to it.
        self._pool = ThreadPool(worker_count)
        self._worker_threads = [self._pool.apply_async(self.loop) for _ in range(worker_count)]
        return self._worker_threads

    def _reconnect_pq(self):
        """
        Replace the PQ connection and queue handle after a failure.

        On failure ``self._queue`` is set to None so that the next poll fails fast and
        triggers another reconnect attempt.
        :return:
        """
        logger.info('Recreate connection to PQ')
        self._pq_reconnect_count += 1

        try:
            self._pq = PQ(connect(**self._connect_kwargs))
            self._queue = self._pq[self.queue_name]
            # Idempotent; makes sure the new queue object dispatches our jobs to _call
            # even if pq keeps its handler registry per queue instance.
            self._queue.handler_registry[self.__name__] = self._call
        except Exception as e:
            logger.exception('Unknown error: %s', e)
            self._queue = None

        if self._pq_reconnect_count > WorkerQueue.MAX_PQ_RECONNECT_COUNT:
            logger.error('Too many attempts to connect to PQ')

    def loop(self):
        """
        Poll the queue until kill() is called, reconnecting to PQ after any error.
        :return:
        """
        logger.info('Start loop for worker queue %s', self)
        while not self._terminating.value:
            try:
                if self._queue is None:
                    raise RuntimeError('Queue is None, reconnecting...')
                self._queue.work()
            except Exception as e:
                logger.exception('Unknown error: %s', e)
                self._reconnect_pq()
                # Back off so an unreachable database does not turn this into a busy loop.
                time.sleep(self.PQ_RECONNECT_DELAY)

        logger.info('Loop ended for worker queue %s', self)

    def get(self):
        """
        Get single queue element
        :return: Queue element
        """
        return self._queue.get()

    def async_execute(self):
        """
        Execute async single task from queue
        :return: Thread result
        """
        pool = ThreadPool(1)
        result = pool.apply_async(self.execute, ())
        # The result stays retrievable; close() only lets the pool exit once the task is done.
        pool.close()
        return result

    def execute(self):
        """
        Execute single task from queue (``work`` in burst mode); errors are logged and
        followed by a reconnect attempt rather than raised.
        :return:
        """
        try:
            if self._queue is None:
                raise RuntimeError('Queue is None, reconnecting...')
            self._queue.work(True)
            self._pq_reconnect_count = 0
        except Exception as e:
            logger.exception('Unknown exception: %s', e)
            self._reconnect_pq()

    def push(self, username, *args, **kwargs):
        """
        Add task to queue
        :param username: Username of the owner of the task, or None
        :param args: Arguments for task
        :param kwargs: Keyword arguments for task
        :return: The created Operation, linked to the new PQ job
        :raises User.DoesNotExist: if ``username`` is given but unknown (nothing is queued then)
        """
        op_type = self._get_queue_type()
        if op_type not in Operation.ALL_TYPES:
            op_type = Operation.TYPE_OTHER

        # Resolve the owner first, so an unknown username fails before an Operation row
        # and a PQ job are created for it.
        user = User.objects.get(username=username) if username is not None else None

        deploy_ipv6 = get_deploy_box_ipv6()
        if deploy_ipv6 is not None:
            logger.info('Adding %s to queue. Use `ssh -6 %s` for connecting to master and view logs', op_type, deploy_ipv6)

        op = Operation.objects.create(
            operation_status=Operation.STATUS_CREATED,
            master_host=get_deploy_node_fqdn() or socket.getfqdn(),
            master_fqdn=get_deploy_pod_fqdn() or socket.getfqdn(),
            operation_type=op_type
        )

        job_id = self._queue.put(
            dict(
                function=self.__name__,
                args=args,
                kwargs=kwargs,
                retried=0,
                retry_in=self._retry_in,
                max_retries=self._max_retries,
            ), schedule_at=self._schedule_at, expected_at=self._expected_at
        )

        op.queue = Queue.objects.get(id=job_id)
        op.operation_status = Operation.STATUS_SCHEDULED
        if user is not None:
            op.user = user
        op.save()

        return op

    def async_dead_jobs_loops(self):
        """
        Run find_dead_jobs in background loop
        :return: Thread result
        """
        pool = ThreadPool(1)
        result = pool.apply_async(self.dead_jobs_loops, ())
        # close() only lets the pool exit once dead_jobs_loops returns.
        pool.close()
        return result

    def dead_jobs_loops(self):
        """
        Run find_dead_jobs every DEAD_LOOP_DELAY seconds until kill() is called
        :return:
        """
        while True:
            if self._terminating.value:
                break
            self.find_dead_jobs()
            time.sleep(self.DEAD_LOOP_DELAY)

    def _reconnect_db(self):
        """
        Replace the raw psycopg2 connection and reconnect this thread's Django connection.

        Errors are logged instead of raised so that dead_jobs_loops keeps running.
        :return:
        """
        old_conn = self._conn
        try:
            self._conn = connect(**self._connect_kwargs)
        except Exception as e:
            logger.exception('Unknown error: %s', e)
        else:
            # Close the replaced connection explicitly instead of relying on garbage collection.
            try:
                old_conn.close()
            except Exception as e:
                logger.warning('Failed to close old connection: %s', e)

        try:
            django_connect = django_db.connections['default']
            django_connect.connect()
        except Exception as e:
            logger.exception('Unknown error: %s', e)

    def find_dead_jobs(self):
        """
        Reschedule all dead jobs (scheduled/rescheduled/in progress for a long time without updates).
        On failure the database connections are recreated.
        :return:
        """
        try:
            self._find_dead_jobs()
            self._reconnect_count = 0
        except Exception as e:
            logger.exception('Unknown error: %s', e)
            logger.info('Create new connection to database for sure')
            self._reconnect_count += 1
            self._reconnect_db()

        if self._reconnect_count > WorkerQueue.MAX_RECONNECT_COUNT:
            logger.error('Too many failed attempts')
            # TODO may be restart worker?

    def _find_dead_jobs(self):
        """
        Operations in IN_PROGRESS/SCHEDULED/RESCHEDULED whose 'updated_at' is older than
        MAX_UPDATE_TIMEOUT get their PQ job's 'dequeued_at' cleared (presumably so that PQ
        hands the job out again — confirm against pq) and are marked STATUS_RESCHEDULED.
        """
        # TODO: Move under transaction & mark with select_for_update
        time_threshold = datetime.datetime.now() - datetime.timedelta(seconds=self.MAX_UPDATE_TIMEOUT)
        dead_operations = Operation.objects.filter(updated_at__lt=time_threshold,
                                                   operation_status__in=[Operation.STATUS_IN_PROGRESS,
                                                                         Operation.STATUS_SCHEDULED,
                                                                         Operation.STATUS_RESCHEDULED])
        sql = "UPDATE queue SET dequeued_at = NULL WHERE id=%s"
        for op in dead_operations:
            # queue_id is the raw foreign key: no extra query per operation, and a missing
            # queue no longer aborts the whole run with an AttributeError.
            job_id = op.queue_id
            if job_id is None:
                logger.warning('Operation %s has no queue job, cannot reschedule it', op.id)
                continue
            with self._conn:
                with self._conn.cursor() as cursor:
                    cursor.execute(sql, (job_id, ))
                    rowcount = cursor.rowcount
            if rowcount:
                logger.info('Setting Operation.STATUS_RESCHEDULED for {}: job_id{}'.format(op.id, job_id))
                op.operation_status = Operation.STATUS_RESCHEDULED
                op.save()


@add_metaclass(ABCMeta)
class Worker(object):
    """
    Base class for job implementations executed by WorkerQueue.

    Constructing a worker marks its Operation as started and records which worker
    (a random uuid4 string) and host picked it up. Subclasses implement ``run`` and
    ``get_last_error``.
    """

    def __init__(self, op):
        """
        :param op: Operation this worker executes; it is marked STATUS_STARTED and saved
        """
        self.operation = op
        self._worker_uid = str(uuid4())
        # Resolve the local FQDN once: socket.getfqdn() may hit DNS and was previously
        # evaluated up to three times per job.
        local_fqdn = socket.getfqdn()

        deploy_ipv6 = get_deploy_box_ipv6()
        if deploy_ipv6 is not None:
            logger.info('Executing %s (job_id %s). Use `ssh -6 root@%s` for connecting to worker and view logs', self.__class__.__name__, op.queue_id, deploy_ipv6)
        logger.debug('Setting status for job %s to started on worker %s on host %s', op.queue_id, self._worker_uid, local_fqdn)

        op.operation_status = Operation.STATUS_STARTED
        op.started_at = datetime.datetime.now()
        op.worker = self._worker_uid

        op.worker_host = get_deploy_node_fqdn() or local_fqdn
        op.worker_fqdn = get_deploy_pod_fqdn() or local_fqdn

        op.save()

    @abstractmethod
    def run(self, *args, **kwargs):
        """
        Abstract method for calling in workers
        :param args: Method args
        :param kwargs: Method kwargs
        :return: Job result (passed to set_result by the queue)
        """
        raise NotImplementedError('Subclasses should implement run()')

    def set_error(self, exception=None, tb=None):
        """
        Set error in operations table for job, together with get_last_error()
        :param exception: exception object
        :param tb: traceback of error
        :return:
        """
        self.operation.set_error(self.get_last_error(), exception, tb)

    def set_result(self, result):
        """
        Mark operation as 'finished' and set result
        :param result: results of operation
        :return:
        """
        self.operation.set_result(result)

    def set_progress(self):
        """
        Mark operation as 'in_progress'.
        :return:
        """
        # TODO: allow progress tracking
        self.operation.set_progress()

    @abstractmethod
    def get_last_error(self):
        """
        Get last error in worker
        :return: Last error or None
        """
        return None

    def add_message(self, message):
        """
        Add OperationMessage to current operation, tagged with the caller's module name and line.
        Note: This method currently interacts directly with django db objects
        :param message: message to be recorded in db
        :return:
        """
        # TODO: Allow this method to work in async manner

        # Only the immediate caller is needed; inspect.stack() would build frame info
        # (including source context read from disk) for every frame on the stack.
        frame = inspect.currentframe()
        caller_frame = frame.f_back
        try:
            caller_module = inspect.getmodule(caller_frame)
            if caller_module is not None:
                module_name = caller_module.__name__
            else:
                # getmodule() returns None for code without a backing module (e.g. exec'd code).
                module_name = caller_frame.f_globals.get('__name__')
            lineno = caller_frame.f_lineno
        finally:
            # Drop frame references to avoid reference cycles.
            del frame, caller_frame

        OperationMessage(
            operation=self.operation,
            message=message,
            module_name=module_name,
            lineno=lineno
        ).save()


# Public API of this module: the queue base class and the worker base class.
__all__ = ['WorkerQueue', 'Worker']
