#!/usr/bin/python3
# -*- coding: utf8 -*-

import subprocess
import os
import sys
import difflib
import json
import shutil
import zipfile
import tempfile
import re
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from pathlib import Path

# Fernflower decompiler jar, used by `diff-classes --method fernflower`.
FERNFLOWER_JAR = '/usr/lib/fernflower/fernflower.jar'

# JDK binaries used to run app tools, compile patched sources and disassemble classes.
JAVA_DIR = '/usr/local/yandex-direct-jdk11/bin/'
JAVA, JAVAC, JAVAP = (JAVA_DIR + tool for tool in ('java', 'javac', 'javap'))

# Tool entry points that are run with the application's own classpath.
GET_VERSION_TOOL_CLASS = "ru.yandex.direct.version.GetDirectVersionTool"
WRITE_CLASS_TO_STDOUT_TOOL_CLASS = "ru.yandex.direct.common.util.WriteClassToStdoutTool"

# Known applications, keyed by the -a/--app value. For each app:
#   'classpath-dir' - classpath directory of the deployed app; passed to
#                     java/javac/javap as '<dir>/*'.
#   'working-dir'   - default hotfix workspace; per-version original/,
#                     patched/ and classes/ directories are created below it.
# Explicit -c/--classpath-dir and -w/--working-dir options take precedence.
APPS = {
    'intapi': {
        'classpath-dir': '/var/www/direct-intapi/direct-intapi',
        'working-dir': '/var/www/direct-intapi/hotfix'
    },
    'web': {
        'classpath-dir': '/var/www/direct-web/direct-web',
        'working-dir': '/var/www/direct-web/hotfix'
    },
    'api5': {
        'classpath-dir': '/var/www/direct-api5/direct-api5',
        'working-dir': '/var/www/direct-api5/hotfix'
    },
    'jobs': {
        'classpath-dir': '/var/www/direct-jobs/lib',
        'working-dir': '/var/www/direct-jobs/hotfix'
    },
    'logviewer': {
        'classpath-dir': '/var/www/direct-logviewer/direct-logviewer',
        'working-dir': '/var/www/direct-logviewer/hotfix'
    },
    'binlogbroker': {
        'classpath-dir': '/var/www/direct-binlog-logbrokerwriter/lib',
        'working-dir': '/var/www/direct-binlog-logbrokerwriter/hotfix'
    },
    'ess-router': {
        'classpath-dir': '/var/www/direct-ess-router/lib',
        'working-dir': '/var/www/direct-ess-router/hotfix'
    },
    'binlog2yt': {
        'classpath-dir': '/usr/lib/yandex-direct-binlog-to-yt-sync/mysql-direct-yt-sync',
        'working-dir': '/usr/lib/yandex-direct-binlog-to-yt-sync/hotfix'
    }
}


def debug(*args, **kwargs):
    """Print like ``print()``, but always to stderr.

    Any ``file=`` keyword passed by the caller is overridden, so diagnostic
    output never mixes with the diffs written to stdout.
    """
    print(*args, **dict(kwargs, file=sys.stderr))


# Rendering adapted from https://stackoverflow.com/a/49912639
class DirectoryTree:
    """One node of a printable, ``tree``-like directory listing.

    Build nodes with :meth:`make_tree` and print each node's :meth:`display`
    line in iteration order. Inside the tree, a directory whose only entry is
    another directory is collapsed into a single ``a/b/`` node.
    """

    # Connector drawn in front of a node's own name.
    display_filename_prefix_middle = '├──'
    display_filename_prefix_last = '└──'
    # Indentation contributed by each ancestor: blank when that ancestor was the
    # last child of its parent, a vertical bar otherwise.
    display_parent_prefix_middle = '    '
    display_parent_prefix_last = '│   '

    def __init__(self, path, parent, is_last):
        self.path = path
        self.parent = parent
        self.is_last = is_last

    def display(self):
        """Return the rendered line (indentation, connector and name) for this node."""
        if self.parent is None:
            label = str(self.path)
        else:
            label = str(self.path.relative_to(self.parent.path))
        if self.path.is_dir():
            label += '/'

        # The root is printed as its plain path, without any connector.
        if self.parent is None:
            return label

        connector = self.display_filename_prefix_last if self.is_last else self.display_filename_prefix_middle

        # Walk up to (but excluding) the root, collecting one indent per ancestor.
        indents = []
        ancestor = self.parent
        while ancestor is not None and ancestor.parent is not None:
            if ancestor.is_last:
                indents.append(self.display_parent_prefix_middle)
            else:
                indents.append(self.display_parent_prefix_last)
            ancestor = ancestor.parent
        indents.reverse()

        return ''.join(indents) + '{} {}'.format(connector, label)

    @classmethod
    def make_tree(cls, path):
        """Yield the nodes of the tree rooted at directory *path*, root first."""
        return cls._make_tree(path, None, True)

    @classmethod
    def _make_tree(cls, path, parent, is_last):
        """Yield the node for *path* followed by all of its descendants."""
        entries = sorted(path.iterdir(), key=lambda entry: str(entry).lower())

        # A lone sub-directory is merged into its parent's node (never at the root).
        if parent is not None and len(entries) == 1 and entries[0].is_dir():
            yield from cls._make_tree(entries[0], parent, is_last)
            return

        node = cls(path, parent, is_last)
        yield node

        sub_dirs = [entry for entry in entries if entry.is_dir()]
        plain_files = [entry for entry in entries if not entry.is_dir()]
        total = len(entries)

        # Directories come first, then files; only the final entry overall is "last".
        for position, sub_dir in enumerate(sub_dirs, start=1):
            yield from cls._make_tree(sub_dir, node, position == total)
        for position, plain_file in enumerate(plain_files, start=len(sub_dirs) + 1):
            yield cls(plain_file, node, position == total)


class Version:
    """Deployed application version together with the svn URL it was built from."""

    def __init__(self, major, minor, source_url):
        self.major = major
        # Without a separate minor revision, the major one is used everywhere.
        self.minor = minor or major
        self.source_url = source_url

    def get_file_url(self, file):
        """Return an ``svn+ssh://`` URL for *file*, pinned with ``@<minor>``.

        Every occurrence of ``svn://`` and ``svn+ssh://`` is removed from the
        source URL before the ``svn+ssh://`` scheme is put back in front. When
        ``$SUDO_USER`` is set, that user name is prefixed as ``user@host``.
        """
        location = self.source_url
        for scheme in (r'svn://', r'svn+ssh://'):
            location = location.replace(scheme, '')

        invoking_user = os.getenv('SUDO_USER')
        if invoking_user is not None:
            location = '{}@{}'.format(invoking_user, location)

        return r'svn+ssh://' + '{}/{}@{}'.format(location, file, self.minor)

    def __repr__(self):
        # Also used (via str()) as the per-version working directory name.
        if self.major == self.minor:
            return '1.{}-1'.format(self.major)
        return '1.{}.{}-1'.format(self.major, self.minor)


class JavaHotfix:
    """Implementation of the hotfix sub-commands.

    The state lives on disk under ``<working-dir>/<app version>/``:

    * ``original/`` holds sources exactly as checked out from svn.
    * ``patched/`` holds the copies meant to be edited.
    * ``classes/`` holds the classes compiled from ``patched/``.
    """

    def __init__(self):
        # Parsed command-line arguments; set in common().
        self.args = None

        # Absolute directories, resolved in common().
        self.working_dir = None
        self.classpath_dir = None

        # Deployed app version and the per-version directories derived from it.
        self.version = None
        self.version_dir = None
        self.original_dir = None
        self.patched_dir = None
        self.classes_dir = None

    def common(self, args):
        """Resolve directories and the app version, then run the selected sub-command.

        :param args: argparse namespace. ``args.func`` is the sub-command handler
            and is called with the same namespace.
        Exits the process (``sys.exit``) on missing or invalid configuration.
        """
        self.args = args

        if 'func' not in args:
            # No sub-command was given. `parser` is the module-level ArgumentParser
            # created in the __main__ section of this script.
            parser.print_help()
            sys.exit()

        if args.app is not None:
            if args.app not in APPS:
                app_list = ', '.join(APPS.keys())
                sys.exit("\nERROR: Unknown app '{}'. Possible values: {}".format(args.app, app_list))

            app = APPS[args.app]

            # Explicit -c / -w options take precedence over the app defaults.
            if args.classpath_dir is None:
                args.classpath_dir = app['classpath-dir']

            if args.working_dir is None:
                args.working_dir = app['working-dir']

        if args.working_dir is None and args.classpath_dir is None:
            sys.exit('Specify app (--app, possible values: {}) '
                     'or working directories (--classpath-dir and --working-dir)'.format(', '.join(APPS.keys())))

        if args.working_dir is None:
            sys.exit('\nERROR: Working directory is not specified (--working-dir)')

        if args.classpath_dir is None:
            sys.exit('\nERROR: Classpath directory is not specified (--classpath-dir)')

        self.working_dir = Path(args.working_dir).absolute()
        self.classpath_dir = Path(args.classpath_dir).absolute()

        if not self.classpath_dir.is_dir():
            sys.exit('\nERROR: Directory does not exist: {}'.format(self.classpath_dir))

        if not self.working_dir.exists():
            self.working_dir.mkdir(parents=True)

        if not self.working_dir.is_dir():
            sys.exit('\nERROR: Not a directory: {}'.format(self.working_dir))

        self.version = self.get_version()

        # Every deployed version gets its own sub-tree, so hotfixes for
        # different versions do not mix.
        self.version_dir = self.working_dir / str(self.version)
        self.original_dir = self.version_dir / 'original'
        self.patched_dir = self.version_dir / 'patched'
        self.classes_dir = self.version_dir / 'classes'

        debug('Version: ', self.version)
        debug()

        args.func(args)

    def info(self, args):
        """Print the directories in use and the trees of original/, patched/ and classes/."""
        print('Working directory:', str(self.version_dir))
        print('Classpath directory:', str(self.classpath_dir))

        def print_dir(dir):
            if dir.exists():
                for path in DirectoryTree.make_tree(dir):
                    print(path.display())
            else:
                print("Directory '{}' does not exist".format(dir))

        print()
        print("Original sources:")
        print_dir(self.original_dir)
        print()
        print("Patched sources:")
        print_dir(self.patched_dir)
        print()
        print("Compiled classes:")
        print_dir(self.classes_dir)

    def checkout(self, args):
        """Check out ``args.files`` at the deployed revision into original/ and patched/.

        Each entry may be a repository-relative path or an Arcadia web URL.
        An existing copy that differs from svn is replaced only when
        -o/--overwrite-original or -p/--overwrite-patched is given. Otherwise
        the process exits with an error.
        """
        for file in args.files:
            file = arcadia_url_to_file_path(file)

            # Make the path relative *before* using it. Joining a Path with an
            # absolute string discards the base directory, which would write the
            # original copy outside the working directory.
            file = file.lstrip('/')

            if file.startswith('arcadia'):
                debug("\nWARN: File path starts with 'arcadia'. Example path: "
                      "'direct/jobs/src/main/java/ru/yandex/direct/jobs/videogoals/VideoGoalBaseJob.java'\n")

            svn_url = self.version.get_file_url(file)
            original_path = self.original_dir / file
            patched_path = self.patched_dir / file

            file_content = svn_cat(svn_url)

            if file_content is None:
                sys.exit('\nERROR: Could not checkout file from svn {}'.format(file))

            # Original file: its content is bound to the revision, so a mismatch is unexpected.
            if original_path.exists() and not args.overwrite_original:
                with original_path.open('r') as f:
                    original_content = f.read()

                if original_content != file_content:
                    sys.exit('\nERROR: File {} is already present in /original/ directory and differs from checked out '
                             'one. This should not happen, file content should be bound to revision. Specify '
                             '-o/--overwrite-original to overwrite'.format(original_path))

            if not original_path.parent.exists():
                original_path.parent.mkdir(parents=True)

            with original_path.open('w') as f:
                f.write(file_content)

            # Patched file: do not clobber local edits unless explicitly asked to.
            if patched_path.exists() and not args.overwrite_patched:
                with patched_path.open('r') as f:
                    patched_content = f.read()

                if patched_content != file_content:
                    sys.exit('\nERROR: File {} is already present in /patched/ directory and differs from checked out '
                             'one. Specify -p/--overwrite-patched to overwrite'.format(patched_path))

            if not patched_path.parent.exists():
                patched_path.parent.mkdir(parents=True)

            with patched_path.open('w') as f:
                f.write(file_content)

            debug('\nFile {} successfully checked out to {}'.format(Path(file).name, str(patched_path)))

    def edit(self, args):
        """Open the patched copies of all known source files in $EDITOR (default: vim)."""
        patched_files = [str(self.patched_dir / file) for file in self.get_source_files()]
        editor = os.getenv('EDITOR', 'vim')

        command = [editor] + patched_files
        debug('$', ' '.join(command))
        subprocess.call(command)

    def diff(self, args):
        """Print a unified diff between original/ and patched/ for every known source file.

        A file missing on one side is diffed against an empty file.
        """
        patched_files = self.get_source_files()

        content = {}
        for file in patched_files:
            original_path = self.original_dir / file
            patched_path = self.patched_dir / file

            original_content = []
            patched_content = []

            if original_path.exists():
                with original_path.open('r') as f:
                    original_content = f.readlines()

            if patched_path.exists():
                with patched_path.open('r') as f:
                    patched_content = f.readlines()

            content[file] = (original_content, patched_content)

        self.print_diff(content, args.color)

    def compile(self, args):
        """Compile every .java file under patched/ into a freshly emptied classes/.

        Exits with an error if javac fails.
        """
        recreate_dir(self.classes_dir)

        source_files = list(get_dir_files(self.patched_dir))

        for file in source_files:
            if file.suffix != '.java':
                debug('WARN: Not .java file in patched dir: ' + str(self.patched_dir / file))

        java_files = [str(self.patched_dir / file) for file in source_files if file.suffix == '.java']

        # '<dir>/*' is the Java classpath wildcard for the jars in that directory.
        command = [JAVAC, '-g', '-encoding', 'UTF-8', '-cp', str(self.classpath_dir) + '/*',
                   '-d', str(self.classes_dir)] + java_files
        debug('$', ' '.join(command))
        returncode = subprocess.call(command)
        if returncode != 0:
            # Propagate the failure instead of silently reporting success.
            sys.exit('\nERROR: javac process exited with code {}'.format(returncode))

    def diff_classes(self, args):
        """Diff each class in classes/ against the same class on the app classpath.

        ``args.method`` selects the representation: ``javap`` disassembly or
        ``fernflower`` decompiled source.
        """
        class_files = list(get_dir_files(self.classes_dir))

        for file in class_files:
            if file.suffix != '.class':
                # The file lives in classes/, so report that path (not the classpath one).
                debug('WARN: Not .class file in classes dir: ' + str(self.classes_dir / file))

        class_files = [file for file in class_files if file.suffix == '.class']

        content = {}
        if args.method == 'javap':
            def remove_ref(line):
                # Mask constant-pool indices (#12 -> #XX), keeping the width. The
                # indices likely differ between separately compiled classes.
                return re.sub(r'#\d+', lambda x: '#' + 'X' * (len(x.group()) - 1), line)

            for file in class_files:
                # e.g. ru/yandex/Foo.class -> ru/yandex/Foo
                class_name = str(file.with_suffix(''))

                patched_content = run_javap(['-cp', str(self.classes_dir), class_name]).splitlines(True)
                classpath_content = run_javap(['-cp', str(self.classpath_dir) + '/*', class_name]).splitlines(True)

                if args.remove_cp_ref:
                    patched_content = [remove_ref(line) for line in patched_content]
                    classpath_content = [remove_ref(line) for line in classpath_content]

                content[file] = (classpath_content, patched_content)
            self.print_diff(content, args.color)

        elif args.method == 'fernflower':
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_dir_path = Path(tmp_dir)

                for file in class_files:
                    class_name = file.stem
                    class_full_path = str(file.parent / class_name).replace('/', '.')

                    # Both decompilations produce <class_name>.java in the same temp
                    # directory, so each result is read right after it is produced.
                    patched_file = self.classes_dir / file
                    run_fernflower(patched_file, tmp_dir_path)
                    with (tmp_dir_path / (class_name + '.java')).open('r') as f:
                        patched_content = f.read().splitlines(True)

                    # Dump the classpath version of the class and decompile it the same way.
                    original_file = tmp_dir_path / (class_name + '.class')
                    with original_file.open('wb') as f:
                        self.write_class_to_file(class_full_path, f)
                    run_fernflower(original_file, tmp_dir_path)
                    with (tmp_dir_path / (class_name + '.java')).open('r') as f:
                        original_content = f.read().splitlines(True)

                    content[file] = (original_content, patched_content)

            self.print_diff(content, args.color)
        else:
            options = ', '.join(['javap', 'fernflower'])
            sys.exit("\nUnknown decompilation method '{}', possible values: {}".format(args.method, options))

    def clean(self, args):
        """Delete the whole working directory of the current version."""
        if self.version_dir.exists():
            shutil.rmtree(str(self.version_dir))
        debug("Successfully deleted directory {}".format(self.version_dir))

    def clean_classes(self, args):
        """Delete the compiled classes of the current version."""
        if self.classes_dir.exists():
            shutil.rmtree(str(self.classes_dir))
        debug("Successfully deleted directory {}".format(self.classes_dir))

    @staticmethod
    def print_diff(content, color=False):
        """Write a diff for every ``{file: (old_lines, new_lines)}`` entry to stdout.

        Entries without differences are skipped. With *color*, ANSI-colored
        output from colored_diff is used instead of difflib.unified_diff.
        """
        for (file, (original_content, patched_content)) in content.items():
            diff_func = difflib.unified_diff if not color else colored_diff
            diff_lines = diff_func(original_content, patched_content, fromfile=file, tofile=file)
            diff_lines = list(diff_lines)
            if len(diff_lines) > 0:
                sys.stdout.writelines(diff_lines)
                print()

    def get_source_files(self):
        """Return the sorted union of relative file paths present in original/ and patched/."""
        files = set(get_dir_files(self.original_dir))
        files.update(get_dir_files(self.patched_dir))
        # Sorted so that diff/edit output order is deterministic.
        return sorted(files)

    def get_version(self):
        """Ask the deployed app for its version and return it as a Version.

        Runs GET_VERSION_TOOL_CLASS on the app classpath and parses its JSON
        output (keys ``sourceUrl``, ``major``, ``minor``). Exits the process if
        the tool fails or reports an unusable version.
        """
        command = [JAVA, '-cp', str(self.classpath_dir) + '/*', GET_VERSION_TOOL_CLASS]

        try:
            debug('$', ' '.join(command))
            result = subprocess.check_output(command, universal_newlines=True)
        except subprocess.CalledProcessError as e:
            sys.exit('\nERROR: Could not get app version, process exit code is {}'.format(e.returncode))

        bad_version_message = ('\nERROR: Something wrong with app version. '
                               'This can happen with custom builds: {}'.format(result))

        try:
            version_info = json.loads(result)
            source_url = version_info['sourceUrl'] or None  # empty string -> None
            major = to_int_or_none(version_info['major'])
            minor = to_int_or_none(version_info['minor'])
        except (ValueError, TypeError, KeyError):
            # Non-JSON output, missing keys or non-numeric values such as null.
            sys.exit(bad_version_message)

        if not major or major <= 0 or not minor or minor <= 0 or not source_url:
            sys.exit(bad_version_message)

        return Version(major, minor, source_url)

    def write_class_to_file(self, class_name, file):
        """Write the bytes of *class_name* from the app classpath into the open binary *file*.

        Exits the process if the helper tool fails.
        """
        command = [JAVA, '-cp', str(self.classpath_dir) + '/*', WRITE_CLASS_TO_STDOUT_TOOL_CLASS, class_name]

        try:
            debug('$', ' '.join(command))
            subprocess.check_call(command, stdout=file)
        except subprocess.CalledProcessError as e:
            sys.exit('\nERROR: Could not get class from app classpath {}'.format(e.returncode))


def arcadia_url_to_file_path(file):
    """Turn an Arcadia web-UI URL into a repository-relative file path.

    Accepts trunk URLs (``.../arc/trunk/arcadia/<path>``) and Direct release
    branch URLs (``.../arc/branches/direct/release/<app>/<number>/arcadia/<path>``,
    where ``<app>`` may contain digits, e.g. ``api5``). A trailing query string
    (``?rev=...``) and fragment (``#L10``, ``#L10-20``) are dropped.
    Any other input is returned unchanged.
    """
    regex = (r"^https?://a\.yandex-team\.ru/arc/"
             r"(?:trunk/arcadia/|branches/direct/release/[a-z0-9-]+/\d+/arcadia/)"
             r"([^?#]*)(?:\?[^#]*)?(?:#.*)?$")
    match = re.match(regex, file)
    if match:
        return match.group(1)
    return file


def run_fernflower(files, destination):
    """Decompile *files* with fernflower, writing the results into *destination*.

    :param files: one path, or a list of paths (anything ``str()``-able).
    :param destination: output directory passed as fernflower's last argument.
    Exits the process if fernflower returns a non-zero code.
    """
    targets = files if isinstance(files, list) else [files]

    command = [JAVA, '-jar', FERNFLOWER_JAR]
    command += [str(target) for target in targets]
    command.append(str(destination))

    try:
        debug('$', ' '.join(command))
        subprocess.check_call(command, universal_newlines=True)
    except subprocess.CalledProcessError as error:
        sys.exit('\nERROR: fernflower process exited with code {}'.format(error.returncode))


def run_javap(args):
    """Run verbose ``javap`` (private members, bytecode, constants) and return its text output.

    :param args: extra javap arguments as a list, e.g. ``['-cp', dir, class_name]``.
    Exits the process if javap returns a non-zero code.
    """
    base_command = [JAVAP, '-p', '-c', '-v', '-constants']
    command = base_command + args
    debug('$', ' '.join(command))
    try:
        return subprocess.check_output(command, universal_newlines=True)
    except subprocess.CalledProcessError as error:
        sys.exit('\nERROR: javap process exited with code {}'.format(error.returncode))


def recreate_dir(dir):
    """Make *dir* (a Path) an empty directory, deleting any existing contents first."""
    try:
        shutil.rmtree(str(dir))
    except FileNotFoundError:
        # Nothing to delete yet.
        pass
    dir.mkdir(parents=True)


def get_dir_files(dir):
    """Yield paths, relative to *dir*, of all files below it whose names do not start with '.'.

    Only file names are filtered. Files inside dot-directories are still
    yielded. A missing *dir* yields nothing.
    """
    for current, _subdirs, names in os.walk(str(dir)):
        base = Path(current)
        yield from (
            (base / name).relative_to(dir)
            for name in names
            if not name.startswith('.')
        )


def svn_cat(svn_url):
    """Return the text content of *svn_url* via ``svn cat``, or None if svn fails."""
    command = ['svn', 'cat', svn_url]
    debug('$', ' '.join(command))
    try:
        return subprocess.check_output(command, universal_newlines=True)
    except subprocess.CalledProcessError:
        # Callers treat None as "could not check out".
        return None


# Adapted from difflib.unified_diff:
# https://github.com/santazhang/Python-3.4.0/blob/master/Lib/difflib.py#L1138
def colored_diff(a, b, fromfile='', tofile='', fromfiledate='',
                 tofiledate='', n=3, lineterm='\n'):
    """Yield unified-diff lines like difflib.unified_diff, with ANSI colors.

    Removed lines are wrapped in red and added lines in green escape codes.
    Headers, hunk markers and context lines are left uncolored. Nothing is
    yielded when *a* and *b* are equal.
    """
    red, green, reset = '\033[0;31m', '\033[0;32m', '\033[0;0m'

    matcher = difflib.SequenceMatcher(None, a, b)
    for index, group in enumerate(matcher.get_grouped_opcodes(n)):
        if index == 0:
            # File header, emitted only once the first difference is found.
            from_suffix = '\t{}'.format(fromfiledate) if fromfiledate else ''
            to_suffix = '\t{}'.format(tofiledate) if tofiledate else ''
            yield '--- {}{}{}'.format(fromfile, from_suffix, lineterm)
            yield '+++ {}{}{}'.format(tofile, to_suffix, lineterm)

        _, a_start, _, b_start, _ = group[0]
        _, _, a_end, _, b_end = group[-1]
        yield '@@ -{} +{} @@{}'.format(difflib._format_range_unified(a_start, a_end),
                                       difflib._format_range_unified(b_start, b_end),
                                       lineterm)

        for tag, i1, i2, j1, j2 in group:
            if tag == 'equal':
                yield from (' ' + line for line in a[i1:i2])
                continue
            if tag in ('replace', 'delete'):
                yield from (red + '-' + line + reset for line in a[i1:i2])
            if tag in ('replace', 'insert'):
                yield from (green + '+' + line + reset for line in b[j1:j2])


def to_int_or_none(value):
    """Return ``int(value)``, or None if *value* cannot be converted.

    Both unparsable strings (ValueError) and unsupported types such as None
    (e.g. a JSON ``null``, TypeError) yield None instead of raising.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


if __name__ == "__main__":
    java_hotfix = JavaHotfix()

    # NOTE: `parser` must stay a module-level name: JavaHotfix.common() calls
    # parser.print_help() when no sub-command is given.
    parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter, epilog="""
Synopsis:

dt-java-hotfix -a logviewer info

dt-java-hotfix -a logviewer checkout direct/common/src/main/java/ru/yandex/direct/common/alive/AliveReporterServlet.java

dt-java-hotfix -a logviewer edit

dt-java-hotfix -a logviewer diff
dt-java-hotfix -a logviewer diff --no-color

dt-java-hotfix -a logviewer compile

dt-java-hotfix -a logviewer diff-classes
dt-java-hotfix -a logviewer diff-classes --method fernflower
dt-java-hotfix -a logviewer diff-classes --no-color --remove-cp-ref

dt-java-hotfix -a logviewer clean
dt-java-hotfix -a logviewer clean-classes

Classpath and working directories can be specified instead of application name:

cd /var/www/beta.darkkeks.10490
dt-java-hotfix -c java/direct/web/yandex-direct-web -w java/direct/web/hotfix
""")
    # Global options. Explicit -c/-w take precedence over the directories implied by -a.
    parser.add_argument('-a', '--app', help='Application, one of {' + ', '.join(APPS.keys()) + '}')
    parser.add_argument('-c', '--classpath-dir', help='Application classpath directory')
    parser.add_argument('-w', '--working-dir', help='Hotfix working directory')

    # Every sub-command stores its handler in `func`. common() resolves the
    # directories and app version first, then calls args.func(args).
    subparsers = parser.add_subparsers()

    info_parser = subparsers.add_parser('info', help="Current hotfix status")
    info_parser.set_defaults(func=java_hotfix.info)

    checkout_parser = subparsers.add_parser('checkout', help='Checkout svn source file')
    checkout_parser.set_defaults(func=java_hotfix.checkout)
    # Repository-relative paths or Arcadia web URLs.
    checkout_parser.add_argument("files", nargs='+')
    checkout_parser.add_argument("-o", "--overwrite-original", action='store_true')
    checkout_parser.add_argument("-p", "--overwrite-patched", action='store_true')

    edit_parser = subparsers.add_parser('edit', help='Open checked out sources in $EDITOR')
    edit_parser.set_defaults(func=java_hotfix.edit)

    diff_parser = subparsers.add_parser('diff', help='Diff patched and original sources')
    diff_parser.set_defaults(func=java_hotfix.diff)
    # `color` defaults to True; --no-color sets it to False.
    diff_parser.add_argument("--no-color", action="store_false", dest="color")

    compile_parser = subparsers.add_parser('compile', help='Compile patched sources')
    compile_parser.set_defaults(func=java_hotfix.compile)

    diff_classes_parser = subparsers.add_parser('diff-classes', help='Diff patched and classpath classes')
    diff_classes_parser.set_defaults(func=java_hotfix.diff_classes)
    diff_classes_parser.add_argument("-m", "--method", choices=["javap", "fernflower"], default="javap",
                                     help="Decompilation method")
    # `color` defaults to True; --no-color sets it to False.
    diff_classes_parser.add_argument("--no-color", action="store_false", dest="color")
    diff_classes_parser.add_argument("--remove-cp-ref", action="store_true",
                                     help="Replace constant pool references with #XXX")

    clean_parser = subparsers.add_parser('clean', help="Remove current version working directory")
    clean_parser.set_defaults(func=java_hotfix.clean)

    clean_classes_parser = subparsers.add_parser('clean-classes', help="Remove current version compiled classes")
    clean_classes_parser.set_defaults(func=java_hotfix.clean_classes)

    java_hotfix.common(parser.parse_args())
