import time

from gevent import threadpool

from infra.orly.proto import orly_pb2


class RuleCodec(object):
    """
    Converts orly_pb2.Rule messages to and from the bytes stored in etcd.

    The codec is stateless, so both methods are static (matching
    OperationCodec). They remain callable through an instance, e.g.
    ``RuleCodec().encode(r)``.
    """

    @staticmethod
    def encode(r):
        """Return the serialized protobuf bytes of rule `r`."""
        return r.SerializeToString()

    @staticmethod
    def decode(data, rev):
        """
        Parse a Rule from `data` and store etcd revision `rev` in its
        ``meta.revision`` field.

        :param data: serialized orly_pb2.Rule bytes.
        :param rev: etcd revision to record on the decoded message.
        :returns: the decoded orly_pb2.Rule.
        """
        r = orly_pb2.Rule()
        r.MergeFromString(data)
        r.meta.revision = rev
        return r


class OperationCodec(object):
    """
    Stateless converter between orly_pb2.Operation messages and the raw
    bytes kept in etcd.
    """

    @staticmethod
    def encode(o):
        """Serialize operation `o` into protobuf wire-format bytes."""
        payload = o.SerializeToString()
        return payload

    @staticmethod
    def decode(data, rev):
        """
        Build an Operation from serialized `data`, then stamp it with the etcd
        revision `rev` (written to ``meta.revision``).
        """
        operation = orly_pb2.Operation()
        operation.MergeFromString(data)
        operation.meta.revision = rev
        return operation


class Storage(object):
    """
    etcd-backed storage for rules and their active operations.

    We store:
        /ops/{rule}/{object id}  # active operations per rule
        ...
        /rules/{rule}  # rule definition and status
        ...

    Every decoded message gets ``meta.revision`` set to the key's etcd
    mod_revision. Writes compare against it to detect concurrent updates.
    """
    # Keeps GC transactions small enough for etcd: finish_batch issues
    # 1 + GC_BATCH_SIZE compares and as many writes, which must stay below
    # etcd's default limit DefaultMaxTxnOps = 128.
    GC_BATCH_SIZE = 100
    # Sentinel error string to distinguish error conditions
    ERR_OP_IN_PROGRESS = 'operation already in progress'
    # How many optimistic-concurrency attempts put_rule makes before giving up.
    PUT_RULE_ATTEMPTS = 10

    def __init__(self, etcd):
        """
        :type etcd: etcd3.Etcd3Client
        """
        self.etcd = etcd
        self.rule_codec = RuleCodec()
        self.op_codec = OperationCodec()

    @staticmethod
    def _prefix_range_end(prefix):
        """
        Return the exclusive range end covering every key that starts with
        `prefix`. The last character is incremented, e.g. '/ops/r1/' -> '/ops/r10'.

        :param prefix: non-empty key prefix.
        """
        return prefix[:-1] + chr(ord(prefix[-1]) + 1)

    def put_rule(self, rule, timeout=None):
        """
        Creates or updates rule.

        Uses optimistic concurrency:
          * A missing rule is created only if the key still does not exist.
          * An existing rule is overwritten only if its mod revision still
            matches the one we read.
          * On conflict, the current value is taken from the transaction's
            failure branch and the write is retried, up to PUT_RULE_ATTEMPTS
            times.

        On overwrite, ``rule.meta.version`` is bumped and the stored status is
        preserved (copied into `rule`). On create, the version is 1 and the
        caller-supplied status is kept. `rule` is mutated in place.

        NOTE(review): `timeout` is only forwarded to the initial read.

        :returns: (rule, None) on success, or (None, error message) when every
                  attempt conflicted.
        """
        path = '/rules/' + rule.meta.id
        old = self.get_rule(rule.meta.id, timeout=timeout)
        # Snapshot the caller's status. Update attempts overwrite rule.status
        # with the stored one, but a (re)create must persist the caller's.
        own_status = type(rule.status)()
        own_status.CopyFrom(rule.status)
        txn = self.etcd.transactions
        for _ in range(self.PUT_RULE_ATTEMPTS):
            if old is None:
                rule.meta.version = 1
                rule.status.CopyFrom(own_status)
                # Create only if the key does not exist yet.
                compare = [txn.version(path) < 1]
            else:  # Rule exists, we must overwrite it, preserving status
                rule.meta.version = old.meta.version + 1
                rule.status.CopyFrom(old.status)
                # Overwrite only the exact revision we have read.
                compare = [txn.mod(path) == old.meta.revision]
            status, responses = self.etcd.transaction(
                compare=compare,
                success=[
                    txn.put(path, self.rule_codec.encode(rule)),
                ],
                failure=[
                    txn.get(path),
                ],
            )
            if status:
                return rule, None
            # The failure branch fetched the current value. It is empty when
            # the rule was deleted concurrently; in that case, recreate it.
            kvs = responses[0] if responses else None
            if kvs:
                buf, meta = kvs[0]
                old = self.rule_codec.decode(buf, meta.mod_revision)
            else:
                old = None
        # If we got here - we failed to update rule too many times
        return None, 'Concurrency failure - failed to update rule'

    def get_rule(self, rule_id, timeout=None):
        """
        Return the rule stored under /rules/{rule_id}, or None if absent.

        NOTE(review): `timeout` is accepted but not forwarded to etcd.

        :raises ValueError: if rule_id is empty.
        """
        if not rule_id:
            raise ValueError('No rule_id specified')
        b, meta = self.etcd.get('/rules/' + rule_id)
        if b is None:
            return None
        return self.rule_codec.decode(b, meta.mod_revision)

    def delete_rule(self, rule_id, timeout=None):
        """
        Delete a rule and all of its active operations in one transaction.

        :returns: True if the transaction succeeded.
        :raises ValueError: if rule_id is empty.
        """
        if not rule_id:
            raise ValueError('No rule_id specified')
        ops_prefix = '/ops/' + rule_id + '/'
        status, _ = self.etcd.transaction(
            compare=[],
            success=[
                self.etcd.transactions.delete('/rules/{}'.format(rule_id)),
                # Operations live under /ops/{rule}/{object id}. Delete the
                # whole key range, not just the (never written) prefix key.
                self.etcd.transactions.delete(
                    ops_prefix,
                    range_end=self._prefix_range_end(ops_prefix),
                ),
            ],
            timeout=timeout,
        )
        return bool(status)

    def list_rules(self, timeout=None):
        """
        Return all rules stored under /rules/.

        NOTE(review): `timeout` is accepted but not forwarded to etcd.
        """
        return [
            self.rule_codec.decode(b, meta.mod_revision)
            for b, meta in self.etcd.get_prefix('/rules/')
        ]

    def list_operations(self, rule_id, timeout=None):
        """
        Return all active operations of rule `rule_id`.

        NOTE(review): `timeout` is accepted but not forwarded to etcd.
        """
        return [
            self.op_codec.decode(b, meta.mod_revision)
            for b, meta in self.etcd.get_prefix('/ops/' + rule_id + '/')
        ]

    def collect_stale_operations(self, get_policy):
        """
        Finish in-progress operations that outlived their rule's spec.duration.

        For each rule, up to GC_BATCH_SIZE expired operations (oldest
        transition time first) are passed to ``get_policy(rule).done(rule, op)``.
        They are then removed, together with a write of the rule, in a single
        transaction (see finish_batch). Rules with nothing expired are not
        written.

        NOTE(review): if finish_batch fails, `done` has already been called for
        that batch and will be called again on a later pass.

        :param get_policy: callable returning the policy object for a rule.
        :returns: (ok, failed) — counts of operations removed / not removed.
        """
        ok = 0
        failed = 0
        for rule in self.list_rules():
            now = time.time()
            op_ids = []
            p = get_policy(rule)
            ops = self.list_operations(rule.meta.id)
            ops.sort(key=lambda o: o.status.in_progress.last_transition_time.seconds)
            for op in ops:
                in_progress = op.status.in_progress
                if in_progress.status != 'True':
                    continue
                if in_progress.last_transition_time.ToSeconds() + rule.spec.duration.ToSeconds() < now:
                    op_ids.append(op.meta.id)
                    p.done(rule, op)
                    if len(op_ids) >= self.GC_BATCH_SIZE:
                        break
            if not op_ids:
                # Nothing expired: rewriting the rule would only bump its mod
                # revision, making concurrent start_operation compares fail.
                continue
            if self.finish_batch(rule, op_ids):
                ok += len(op_ids)
            else:
                failed += len(op_ids)
        return ok, failed

    def finish_batch(self, rule, op_ids):
        """
        Atomically write `rule` and delete the given operations of that rule.

        The transaction applies only if the rule's mod revision still equals
        ``rule.meta.revision`` and every listed operation still exists. If any
        operation is already gone, the whole batch is rejected.

        :returns: the transaction status (True if applied).
        """
        rule_data = self.rule_codec.encode(rule)
        rule_path = '/rules/' + rule.meta.id
        compare = [
            # Check that rule revision is the same
            self.etcd.transactions.mod(rule_path) == rule.meta.revision,
        ]
        success = [
            # Update rule
            self.etcd.transactions.put(rule_path, rule_data),
        ]
        for op_id in op_ids:
            # Check that operation exists for this object
            op_path = '/ops/{}/{}'.format(rule.meta.id, op_id)
            compare.append(self.etcd.transactions.version(op_path) > 0)
            success.append(self.etcd.transactions.delete(op_path))
        status, _ = self.etcd.transaction(compare=compare, success=success, failure=[])
        return status

    def start_operation(self, op, rule, timeout=None):
        """
        Atomically write `rule` and create operation `op` under it.

        The transaction applies only if the rule's mod revision still equals
        ``rule.meta.revision`` and no operation with ``op.meta.id`` exists yet.

        NOTE(review): `timeout` is accepted but not forwarded to etcd.

        :returns: None on success; ERR_OP_IN_PROGRESS if the operation key
                  already exists; 'transaction conflict' otherwise (the rule
                  was modified or removed since it was read).
        """
        rule_data = self.rule_codec.encode(rule)
        op_data = self.op_codec.encode(op)
        rule_path = '/rules/' + rule.meta.id
        op_path = '/ops/{}/{}'.format(rule.meta.id, op.meta.id)
        status, responses = self.etcd.transaction(
            compare=[
                # Check that rule revision is the same
                self.etcd.transactions.mod(rule_path) == rule.meta.revision,
                # Check that operation does not exist for this object
                self.etcd.transactions.version(op_path) < 1,
            ],
            success=[
                self.etcd.transactions.put(rule_path, rule_data),
                self.etcd.transactions.put(op_path, op_data),
            ],
            failure=[
                # Fetch the operation so we can tell "already running" apart
                # from a stale rule revision.
                self.etcd.transactions.get(op_path),
            ],
        )
        # TODO: a removed rule could map to a distinct "not found" result, and
        # an outdated rule revision could be retried instead of reported.
        if not status:
            if responses and responses[0]:
                return self.ERR_OP_IN_PROGRESS
            return 'transaction conflict'
        return None


class GeventFriendlyStorage(object):
    """
    Facade over a Storage whose calls are executed on a gevent ThreadPool.

    Each method forwards to the same-named method of the wrapped `storage`
    via ``ThreadPool.apply``, passes `timeout` on as a keyword argument, and
    returns the result. Presumably the goal is to keep the blocking etcd
    client off the gevent event loop — confirm.
    """

    def __init__(self, storage: Storage, pool_size=3):
        self.storage = storage
        self.pool = threadpool.ThreadPool(pool_size)

    def _offload(self, method, *args, timeout=None):
        # Run `method(*args, timeout=timeout)` in the pool and hand back its result.
        return self.pool.apply(method, args=args, kwds={'timeout': timeout})

    def put_rule(self, rule, timeout=None):
        return self._offload(self.storage.put_rule, rule, timeout=timeout)

    def get_rule(self, rule_id, timeout=None):
        return self._offload(self.storage.get_rule, rule_id, timeout=timeout)

    def delete_rule(self, rule_id, timeout=None):
        return self._offload(self.storage.delete_rule, rule_id, timeout=timeout)

    def list_rules(self, timeout=None):
        return self._offload(self.storage.list_rules, timeout=timeout)

    def list_operations(self, rule_id, timeout=None):
        return self._offload(self.storage.list_operations, rule_id, timeout=timeout)

    def start_operation(self, op, rule, timeout=None):
        return self._offload(self.storage.start_operation, op, rule, timeout=timeout)
