# -*- coding: utf-8 -*-
"""
На основе кода из https://github.com/noplay/python-mysql-replication/blob/master/pymysqlreplication/gtid.py
@12fa0036a5b7120ac2a1b313f3cfda0b0388151c
Код переработан для поддержки python3 и сортировок
"""

import re
from typing import Iterable


def overlap(i1, i2):
    """Return True if right-open intervals *i1* and *i2* share a point.

    Two ``(start, end)`` intervals intersect exactly when each one starts
    before the other one ends; merely touching (``i1[1] == i2[0]``) does
    not count as overlapping.
    """
    starts_before_other_ends = i1[0] < i2[1]
    return starts_before_other_ends and i2[0] < i1[1]


def contains(i1, i2):
    """Return True if interval *i2* lies entirely inside interval *i1*.

    Both bounds are compared inclusively, so an interval contains itself.
    """
    if i2[0] < i1[0]:
        return False
    return i1[1] >= i2[1]


class Gtid:
    """A MySQL GTID: a server UUID plus a set of right-open intervals.

    Internally each interval [a, b) stands for every transaction number x
    executed on the server UUID with

        a <= x < b

    The human representation, though, is either a single transaction
    number A = a (when only one transaction is covered, i.e. b = a + 1)

        UUID:A

    or a closed interval [A, B] for at least two transactions (note that
    in this case b = B + 1)

        UUID:A-B

    Several ranges can be listed for one UUID:

        UUID:1-2:4:6-74

    For convenience, Gtids can be added together and their interval lists
    are merged. Adding TXN 3 to the representation above produces:

        UUID:1-4:6-74

    and adding 5 to this new result:

        UUID:1-74

    Adding an already present transaction number (one that overlaps)
    raises ValueError, as does adding a Gtid with a different UUID.
    Subtracting transactions that are absent, or a Gtid with a different
    UUID, is silently ignored.

    ``self.intervals`` is kept sorted, non-overlapping and with no two
    intervals touching, so ``str()`` (and therefore ``hash()``) is canonical.
    """
    @staticmethod
    def parse_interval(interval):
        """Parse one human-readable interval ("A" or "A-B") into internal form.

        The human form is the closed range [A, B]; the internal form is the
        right-open range [a, b), so the end value is incremented by one.

        :param interval: string such as ``"7"`` or ``"1-74"``.
        :returns: tuple ``(a, b)`` with ``b = B + 1``.
        :raises ValueError: if the string is not a valid interval.
        """
        m = re.search("^([0-9]+)(?:-([0-9]+))?$", interval)
        if not m:
            raise ValueError("GTID format is incorrect: %r" % (interval, ))
        a = int(m.group(1))
        b = int(m.group(2) or a)
        # Reversed ranges (B < A) produce a < b failing here; they are
        # rejected as malformed when the interval is added.
        return a, b + 1

    @staticmethod
    def parse(gtid):
        """Split a human-readable GTID string into its UUID and intervals.

        :param gtid: string such as ``"UUID:1-2:4:6-74"``.
        :returns: ``(uuid, intervals)`` where *intervals* is a list of
            internal right-open ``(a, b)`` tuples in input order.
        :raises ValueError: if the string is not a valid GTID.
        """
        m = re.search("^([0-9a-fA-F]{8}(?:-[0-9a-fA-F]{4}){3}-[0-9a-fA-F]{12})"
                      "((?::[0-9-]+)+)$", gtid)
        if not m:
            raise ValueError("GTID format is incorrect: %r" % (gtid, ))

        uuid = m.group(1)
        intervals = m.group(2)

        # intervals looks like ":1-2:4"; the first split element is empty.
        intervals_parsed = [Gtid.parse_interval(x)
                            for x in intervals.split(":")[1:]]

        return uuid, intervals_parsed

    def __add_interval(self, itvl):
        """Merge the internal-form interval *itvl* into ``self.intervals``.

        Existing intervals touching *itvl* on either side are coalesced with
        it, so the list stays sorted, non-overlapping and non-adjacent.

        :raises ValueError: if *itvl* is empty or reversed, or overlaps an
            interval that is already present.
        """
        # [a, b) must hold at least one transaction. An empty interval
        # (a == b, e.g. parsed from "5-4") would break the canonical form.
        if itvl[0] >= itvl[1]:
            raise ValueError("Malformed interval %s" % (itvl,))

        if any(overlap(x, itvl) for x in self.intervals):
            raise ValueError("Overlapping interval %s" % (itvl,))

        new = []
        for existing in sorted(self.intervals):
            if itvl[0] == existing[1]:
                # existing ends exactly where itvl starts: absorb it on the left
                itvl = (existing[0], itvl[1])
            elif itvl[1] == existing[0]:
                # existing starts exactly where itvl ends: absorb it on the right
                itvl = (itvl[0], existing[1])
            else:
                new.append(existing)

        self.intervals = sorted(new + [itvl])

    def __sub_interval(self, itvl):
        """Remove the internal-form interval *itvl* from ``self.intervals``.

        Transactions of *itvl* that are not present are ignored.

        :raises ValueError: if *itvl* is empty or reversed.
        """
        if itvl[0] >= itvl[1]:
            raise ValueError("Malformed interval %s" % (itvl,))

        if not any(overlap(x, itvl) for x in self.intervals):
            # Nothing to remove; subtraction deliberately does not raise.
            return

        new = []
        for existing in sorted(self.intervals):
            if not overlap(existing, itvl):
                new.append(existing)
                continue
            # Keep whatever part of existing sticks out on either side of itvl.
            if existing[0] < itvl[0]:
                new.append((existing[0], itvl[0]))
            if existing[1] > itvl[1]:
                new.append((itvl[1], existing[1]))

        self.intervals = new

    def __contains__(self, other):
        """Return True if every transaction of Gtid *other* is in this Gtid.

        A Gtid with a different UUID is never contained.
        """
        if other.uuid != self.uuid:
            return False

        return all(any(contains(me, them) for me in self.intervals)
                   for them in other.intervals)

    def __init__(self, gtid, uuid=None, intervals=None):
        """Build a Gtid from a human-readable string or from its parts.

        :param gtid: string such as ``"UUID:1-5:7"``; parsed only when
            *uuid* or *intervals* is missing or empty.
        :param uuid: server UUID, used together with *intervals*.
        :param intervals: iterable of internal right-open ``(a, b)`` tuples.
        :raises ValueError: if the string or an interval is malformed, or
            the intervals overlap.
        """
        if not (uuid and intervals):
            uuid, intervals = Gtid.parse(gtid)

        self.uuid = uuid
        self.intervals = []
        for itvl in intervals:
            self.__add_interval(itvl)

    def _copy(self):
        """Return an independent Gtid with the same UUID and intervals.

        Built directly instead of via ``Gtid(str(self))``. A Gtid whose
        intervals were all subtracted away renders as ``"UUID:"``, which
        ``parse`` rejects.
        """
        clone = Gtid.__new__(Gtid)
        clone.uuid = self.uuid
        clone.intervals = list(self.intervals)
        return clone

    def __add__(self, other):
        """Return a new Gtid with the transactions of both operands.

        :raises ValueError: if the UUIDs differ or the transactions overlap.
        """
        if self.uuid != other.uuid:
            raise ValueError("Attempt to merge different UUID: "
                             "%s != %s" % (self.uuid, other.uuid))

        result = self._copy()
        for itvl in other.intervals:
            result.__add_interval(itvl)

        return result

    def __sub__(self, other):
        """Return a new Gtid without the transactions of *other*.

        Does not raise: a different UUID, or absent transactions, are simply
        ignored and an unchanged copy is returned.
        """
        result = self._copy()
        if self.uuid != other.uuid:
            return result

        for itvl in other.intervals:
            result.__sub_interval(itvl)

        return result

    def __hash__(self):
        # Consistent with __eq__ because intervals are kept canonical.
        return hash(str(self))

    def is_clone_of(self, other):
        """Return True if *other* has the same UUID and identical intervals."""
        return self.uuid == other.uuid and sorted(self.intervals) == sorted(other.intervals)

    def __eq__(self, other):
        if not isinstance(other, Gtid):
            # Let Python fall back instead of crashing on ``in``/``.uuid``.
            return NotImplemented
        return self in other and other in self

    def __ne__(self, other):
        return not (self == other)

    # Ordering: different UUIDs compare by UUID string; within one UUID the
    # order is set inclusion, which is only a partial order.
    def __lt__(self, other):
        if self.uuid != other.uuid:
            return self.uuid < other.uuid
        return self in other and self != other

    def __le__(self, other):
        if self.uuid != other.uuid:
            return self.uuid < other.uuid
        return self in other

    def __gt__(self, other):
        if self.uuid != other.uuid:
            return self.uuid > other.uuid
        return other in self and self != other

    def __ge__(self, other):
        if self.uuid != other.uuid:
            return self.uuid > other.uuid
        return other in self

    def __str__(self):
        """We represent the human value here - a single number
        for one transaction, or a closed interval (decrementing b)"""
        return "%s:%s" % (self.uuid,
                          ":".join(("%d-%d" % (x[0], x[1]-1)) if x[0] + 1 != x[1]
                                   else str(x[0])
                                   for x in self.intervals))

    def __repr__(self):
        return "<Gtid '%s'>" % self


class GtidSet:
    """A MySQL GTID set: an ordered list of :class:`Gtid` entries.

    Built from a comma-separated string such as ``"uuid1:1-5,uuid2:1-3"``
    or from an iterable of Gtid objects and/or GTID strings. Entries keep
    the order they were given in unless *keep_sorted* is true.
    """
    def __init__(self, gtid_set, keep_sorted=False):
        """Create the set.

        :param gtid_set: comma-separated GTID string, an iterable of Gtid
            objects and/or GTID strings, or a falsy value for an empty set.
        :param keep_sorted: if true, entries are sorted (using Gtid
            comparisons, i.e. by UUID) now and after every merge.
        :raises ValueError: if *gtid_set* has an unsupported type or any
            entry is malformed.
        """
        def _to_gtid(element):
            if isinstance(element, Gtid):
                return element
            return Gtid(element.strip(" \n"))

        self.keep_sorted = keep_sorted

        if not gtid_set:
            self.gtids = []
        elif isinstance(gtid_set, Iterable) and not isinstance(gtid_set, str):
            self.gtids = [_to_gtid(x) for x in gtid_set]
        elif isinstance(gtid_set, str):
            self.gtids = [Gtid(x.strip(" \n")) for x in gtid_set.split(",")]
        else:
            raise ValueError("GTID Set format is incorrect: %r" % (gtid_set, ))

        if self.keep_sorted:
            self.gtids.sort()
        # The initially supplied gtid-set must NOT be merged: keep it exactly
        # as MySQL returned it.

    def merge_gtid(self, gtid):
        """Add the transactions of *gtid* to this set in place.

        Every entry with the same UUID is replaced by ``entry + gtid`` (which
        raises if transactions overlap); if no entry has that UUID, *gtid*
        is appended as a new entry.
        """
        new_gtids = []
        for existing in self.gtids:
            if existing.uuid == gtid.uuid:
                new_gtids.append(existing + gtid)
            else:
                new_gtids.append(existing)
        if gtid.uuid not in (x.uuid for x in new_gtids):
            new_gtids.append(gtid)

        self.gtids = sorted(new_gtids) if self.keep_sorted else new_gtids

    def __contains__(self, other):
        """Test membership of a Gtid or inclusion of another GtidSet.

        A Gtid is contained if a single entry of this set contains it; a
        GtidSet is contained if each of its entries is. Transactions of one
        UUID split across several entries are not combined.

        :raises NotImplementedError: for any other operand type.
        """
        if isinstance(other, Gtid):
            return any(other in x for x in self.gtids)
        elif isinstance(other, GtidSet):
            return all(any(them in me for me in self.gtids)
                       for them in other.gtids)

        raise NotImplementedError

    def __add__(self, other):
        """Return a new GtidSet with Gtid *other* merged in.

        The result keeps this set's *keep_sorted* setting.

        :raises NotImplementedError: if *other* is not a Gtid.
        """
        if isinstance(other, Gtid):
            new = GtidSet(self.gtids, keep_sorted=self.keep_sorted)
            new.merge_gtid(other)
            return new
        raise NotImplementedError

    def __iter__(self):
        return iter(self.gtids)

    def __getitem__(self, item):
        return self.gtids[item]

    def _canonical_key(self):
        """Return a hashable, order-independent summary of the set's content.

        Maps every UUID to the sorted union of all its intervals (touching
        and overlapping ranges coalesced, empty ranges dropped). Sets that
        compare equal (mutual containment) always produce the same key.
        """
        intervals_by_uuid = {}
        for gtid in self.gtids:
            intervals_by_uuid.setdefault(gtid.uuid, []).extend(gtid.intervals)

        key = []
        for uuid in sorted(intervals_by_uuid):
            merged = []
            for start, end in sorted(intervals_by_uuid[uuid]):
                if start >= end:
                    continue
                if merged and start <= merged[-1][1]:
                    merged[-1] = (merged[-1][0], max(merged[-1][1], end))
                else:
                    merged.append((start, end))
            key.append((uuid, tuple(merged)))
        return tuple(key)

    def __hash__(self):
        # __eq__ ignores entry order, so hashing str(self) (order-dependent)
        # would give equal sets different hashes.
        return hash(self._canonical_key())

    def is_clone_of(self, other):
        """Return True if both sets hold pairwise identical Gtid entries."""
        return len(self.gtids) == len(other.gtids) and \
            all(x.is_clone_of(y) for x, y in zip(sorted(self.gtids), sorted(other.gtids)))

    def __eq__(self, other):
        if not isinstance(other, GtidSet):
            # Let Python fall back instead of crashing inside __contains__.
            return NotImplemented
        return self in other and other in self

    def __ne__(self, other):
        return not (self == other)

    # Ordering is set inclusion (a partial order).
    def __lt__(self, other):
        # Do not compare by len(gtids): only containment is meaningful.
        return self in other and self != other

    def __le__(self, other):
        return self in other

    def __gt__(self, other):
        return other in self and self != other

    def __ge__(self, other):
        return other in self

    def __str__(self):
        return ",".join(str(x) for x in self.gtids)

    def __repr__(self):
        return "<GtidSet %r>" % self.gtids
