# mpatch.py - Python implementation of mpatch.c
#
# Copyright 2009 Matt Mackall <mpm@selenic.com> and others
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.

import struct
import itertools

try:
    from io import StringIO
except ImportError:
    from io import StringIO

from mercurial import manifest, context, node


class mpatchError(Exception):
    """Signals that a delta could not be decoded."""

# This attempts to apply a series of patches in time proportional to
# the total size of the patches, rather than patches * len(text). This
# means that, rather than shuffling strings around, we shuffle around
# pointers to fragments kept in fragment lists.
#
# When the fragment lists get too long, we collapse them. To do this
# efficiently, we do all our operations inside a single seekable
# buffer and copy within it using _move(), our stand-in for memmove.
# This avoids creating a bunch of large temporary string buffers.


def _pull(dst, src, l):  # pull l bytes from src
    while l:
        f = src.pop()
        if f[0] > l:  # do we need to split?
            src.append((f[0] - l, f[1] + l))
            dst.append((l, f[1]))
            return
        dst.append(f)
        l -= f[0]


def _move(m, dest, src, count):
    """move count bytes from src to dest

    The file pointer is left at the end of dest.
    """
    m.seek(src)
    buf = m.read(count)
    m.seek(dest)
    m.write(buf)


def _collect(m, buf, mod_list):
    """collapse a fragment list into one contiguous run starting at buf

    mod_list is walked from its last entry to its first; each
    (length, offset) fragment is copied with _move() directly behind
    the previous one. Returns the single resulting fragment as
    (total_length, buf).
    """
    write_pos = buf
    for length, offset in reversed(mod_list):
        _move(m, write_pos, offset, length)
        write_pos += length
    return (write_pos - buf, buf)


def _get_changes(text):
    result = (row.split('\0')[:2] for row in text.split('\n') if row)
    return result


def _read_fragments(textio, fragments):
    """return the concatenated content of fragments, or None if empty

    fragments is a list of (length, offset) pairs in text order; each
    one is read from textio at its own offset.
    """
    pieces = []
    for length, offset in fragments:
        textio.seek(offset)
        pieces.append(textio.read(length))
    if not pieces:
        return None
    # join with an empty value of whatever type the buffer yields
    return pieces[0][:0].join(pieces)


def patches(textio, frags, bins):
    """apply the deltas in bins to the text held in textio

    textio is a seekable read/write buffer; frags is the list of
    (length, offset) fragments describing the current text inside it
    (frags[0][0] is taken as the text length when sizing the work
    areas). bins is a sequence of deltas, each a run of hunks made of a
    12-byte big-endian header (start, end, length) followed by length
    bytes of replacement data.

    While patching, the rows deleted and inserted by every hunk are
    parsed with _get_changes() to track per-file changes.

    Returns (frags, modified, added, removed):
      frags    - one-element list [(length, offset)] locating the
                 patched text inside textio (the input list, untouched,
                 when there is nothing to apply)
      modified - {file_name: (old_hash, new_hash)}
      added    - {file_name: (None, new_hash)}
      removed  - {file_name: (old_hash, None)}

    The frags list passed in is consumed/mutated in place.

    Raises mpatchError when a hunk header cannot be unpacked or holds
    inconsistent offsets.

    NOTE(review): struct.unpack needs a bytes-like header on Python 3,
    so textio presumably has to be a binary buffer - confirm against
    DiffManifests, which creates a StringIO.
    """
    modified, added, removed = {}, {}, {}

    if not bins:
        return frags, modified, added, removed

    plens = [len(x) for x in bins]
    pl = sum(plens)
    bl = frags[0][0] + pl
    tl = bl + bl + pl  # enough for the patches and two working texts
    b1, b2 = 0, bl

    if not tl:
        return frags, modified, added, removed

    # copy all the patches into our segment so we can move data out of
    # them with _move()
    pos = b2 + bl
    textio.seek(pos)
    for p in bins:
        textio.write(p)

    for plen in plens:
        # if our list gets too long, execute it
        if len(frags) > 128:
            b2, b1 = b1, b2
            frags = [_collect(textio, b1, frags)]

        new = []
        end = pos + plen
        last = 0

        while pos < end:
            textio.seek(pos)
            binm = textio.read(12)
            try:
                p1, p2, l = struct.unpack(">lll", binm)
            except struct.error:
                raise mpatchError("patch cannot be decoded")
            # hunks must be ordered and describe non-negative spans,
            # otherwise _pull() would build nonsensical fragments
            if p1 < last or p2 < p1 or l < 0:
                raise mpatchError("patch cannot be decoded")

            _pull(new, frags, p1 - last)    # what didn't change
            deleted = []
            _pull(deleted, frags, p2 - p1)  # what got deleted

            # The deleted span is addressed relative to the *current*
            # text, which may be scattered over several fragments (after
            # an earlier delta or a collapse), so it is read through the
            # fragments just pulled rather than at absolute offset p1.
            old_text = _read_fragments(textio, deleted)
            if old_text:
                for file_name, file_hash in _get_changes(old_text):
                    if file_name in added:
                        # added by an earlier delta and gone again: the
                        # file never existed on the parent side
                        del added[file_name]
                    elif file_name in modified:
                        removed[file_name] = (
                            modified.pop(file_name)[0], None)
                    elif file_name not in removed:
                        removed[file_name] = (file_hash, None)

            textio.seek(pos + 12)
            for file_name, file_hash in _get_changes(textio.read(l)):
                if file_name in removed:
                    # deleted and re-inserted: it is a modification
                    modified[file_name] = (
                        removed.pop(file_name)[0], file_hash)
                elif file_name in modified:
                    modified[file_name] = (modified[file_name][0], file_hash)
                elif file_name not in added:
                    added[file_name] = (None, file_hash)

            new.append((l, pos + 12))     # what got added
            pos += l + 12
            last = p2

        frags.extend(reversed(new))     # what was left at the end

    frags = [_collect(textio, b2, frags)]

    return frags, modified, added, removed


class DiffManifests(object):
    """compute the manifest entries that differ between two changesets

    The instance keeps the text of the last manifest it produced in a
    reusable buffer (textio), together with its fragment list and
    revision number, so that a following call whose parent is that same
    revision does not have to load the parent text again.
    """

    def __init__(self, repo):
        self.repo = repo
        self.textio = None      # created on first init_textio() call
        self.frags = []         # (length, offset) fragments of the text
        self.actual_rev = None  # manifest revision currently in textio

    def init_textio(self, rev, text):
        """make text (belonging to manifest revision rev) the buffer content"""
        self.actual_rev = rev
        size = len(text)
        self.frags = [(size, 0)]
        if not self.textio:
            self.textio = StringIO()
        buf = self.textio
        buf.seek(0)
        buf.write(text)
        # drop whatever a previous, longer content left behind
        buf.truncate(size)

    @staticmethod
    def _manifest_text(entries, index):
        """render (file_name, hashes) pairs as sorted manifest text

        Each row is file_name, NUL, hashes[index]; rows are newline
        terminated. An empty string is returned when there are no rows.
        """
        rows = sorted(
            '%s\0%s' % (name, hashes[index]) for name, hashes in entries
        )
        text = '\n'.join(rows)
        if text:
            text += '\n'
        return text

    def get(self, parent_ctx, current_ctx):
        """return (parent_manifest, current_manifest) restricted to changes

        The parent manifest holds the removed and modified files with
        their old hashes; the current manifest holds the modified and
        added files with their new hashes.
        """
        mflog = self.repo.manifest
        parent_rev = mflog.rev(parent_ctx._changeset.manifest)
        current_rev = mflog.rev(current_ctx._changeset.manifest)

        # reload the parent text unless it is what the buffer already holds
        if self.actual_rev != parent_rev:
            self.init_textio(parent_rev, mflog.revision(parent_rev))

        # NOTE(review): only the first element of _deltachain()'s result
        # is used - confirm the chain is guaranteed to start at parent_rev
        chain_info = mflog._deltachain(current_rev, stoprev=parent_rev)
        bins = mflog._chunks(chain_info[0])

        frags, modified, added, removed = patches(
            self.textio, self.frags, bins)

        # pull the patched text out and make it the new buffer content
        length, offset = frags[0][0], frags[0][1]
        self.textio.seek(offset)
        self.init_textio(current_rev, self.textio.read(length))

        pm_text = self._manifest_text(
            itertools.chain(removed.items(), modified.items()), 0)
        cm_text = self._manifest_text(
            itertools.chain(modified.items(), added.items()), 1)

        return (manifest.manifestdict(pm_text),
                manifest.manifestdict(cm_text))


def get_node_rev_by_changeid(changeid, repo):
    """resolve changeid to a (node, rev) pair

    changeid may be an integer revision, one of the special names
    'null', 'tip' or '.', a 20-byte binary node, a 40-character hex
    node, a symbolic name known to repo.names, or a node prefix.

    After the direct forms have been tried, repo.names is always
    consulted; when it knows changeid, its answer is the one returned.
    A prefix match is attempted only if nothing else resolved.

    Returns (None, None) when changeid cannot be resolved.
    """
    # The locals must not be called ``node``: that would shadow the
    # imported ``node`` module needed for nullid/nullrev/bin below.
    found_node = found_rev = None

    if isinstance(changeid, int):
        found_node = repo.changelog.node(changeid)
        found_rev = changeid
    elif changeid == 'null':
        found_node = node.nullid
        found_rev = node.nullrev
    elif changeid == 'tip':
        found_node = repo.changelog.tip()
        found_rev = repo.changelog.rev(found_node)
    elif changeid == '.':
        # this is a hack to delay/avoid loading obsmarkers
        # when we know that '.' won't be hidden
        found_node = repo.dirstate.p1()
        found_rev = repo.unfiltered().changelog.rev(found_node)
    elif len(changeid) == 20:
        found_node = changeid
        found_rev = repo.changelog.rev(changeid)
    elif len(changeid) == 40:
        try:
            # node.bin converts hex to binary (the builtin bin() does not)
            candidate = node.bin(changeid)
            candidate_rev = repo.changelog.rev(candidate)
        except (TypeError, ValueError, LookupError):
            # not valid hex, or not a known node: try the lookups below
            pass
        else:
            found_node, found_rev = candidate, candidate_rev

    try:
        named = repo.names.singlenode(repo, changeid)
        return named, repo.changelog.rev(named)
    except KeyError:
        pass

    if found_node is None:
        found_node = repo.unfiltered().changelog._partialmatch(changeid)
        if found_node is not None:
            found_rev = repo.changelog.rev(found_node)

    return found_node, found_rev
