from __future__ import print_function
from typing import List, OrderedDict, Tuple, Dict, Optional

from pathlib import Path
from subprocess import run, PIPE
from sys import stdout as stdout
from dataclasses import dataclass

import os
import logging
from encodings import utf_8
from datetime import datetime

import xmltodict


# Canonical name of the UTF-8 codec as reported by the stdlib codec registry ("utf-8").
# Used as the encoding argument for the text reads/decodes throughout this module.
UTF_8 = utf_8.getregentry().name


@dataclass(frozen=True)
class SvnLogEntry:
    """A single entry of an SVN log.

    Identity is defined by the revision number alone: two entries with the same
    ``revision`` compare equal and hash identically even if the remaining fields
    differ. This lets entries collected for several projects be de-duplicated by
    putting them into a ``set``.
    """

    revision: int
    author: str
    date: datetime
    message: str

    def __hash__(self):
        # Must stay consistent with __eq__: only the revision participates.
        return hash(self.revision)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types instead of
        # failing with AttributeError on a missing ``revision`` attribute.
        if not isinstance(other, SvnLogEntry):
            return NotImplemented
        return self.revision == other.revision


class App:
    """Builds a per-revision report of proto descriptor sizes for an Arcadia checkout.

    ``run()`` finds PROTO_LIBRARY projects, obtains the size of each project's
    ``.protodesc`` file, loads the SVN log of every project and writes one CSV row per
    revision with the summed size of the projects that revision touched.
    """

    # Not referenced anywhere in the visible part of this class.
    # NOTE(review): presumably the trunk revision the analysis was made against - confirm.
    REFERENCE_REVISION = 8984475

    # Repository root. Read from the environment while the class body is executed, so
    # defining this class raises KeyError when ARCADIA_ROOT is not set.
    ARCADIA_ROOT = Path(os.environ["ARCADIA_ROOT"])
    # The `ya` launcher; invoked as `ya tool cs` and `ya make` by the methods below.
    YA_BINARY = ARCADIA_ROOT / "ya"
    # Folder that holds every file this tool generates (registry, caches, report).
    PROJECT_FOLDER = ARCADIA_ROOT / "tasklet/experimental/registry_data_size_calculator"
    # Location of the generated PROTO_REGISTRY ya.make (see build_registry_ya_make).
    TEST_REGISTRY_FOLDER = PROJECT_FOLDER / "registry"
    TEST_REGISTRY_YA_MAKE = TEST_REGISTRY_FOLDER / "ya.make"
    # Root of the per-project `svn log --xml` cache (<project>/log.xml below it).
    SVN_LOG_CACHE_FOLDER = PROJECT_FOLDER / "svn_log_cache"
    # Append-only cache of "<project>,<size in bytes>" lines.
    PROJECT_SIZE_CACHE_FILE = PROJECT_FOLDER / "size_cache.csv"
    # Report written by run(): revision, ISO date, summed size, ':'-joined projects.
    OUT_FILE = PROJECT_FOLDER / "proto_desc_size_diff.csv"
    # Policy files (relative to ARCADIA_ROOT) scanned for "DENY .* -> <dir>" rules.
    RULES_FILES = [
        "build/rules/junk.policy",
        "build/rules/go/contrib.policy",
        "build/rules/go/vendor.policy",
        "build/rules/contrib_deprecated.policy",
        "build/rules/contrib_python.policy",
        "build/rules/contrib_restricted.policy",
        "build/rules/contrib_deps.policy",
        "build/rules/library_deps.policy",
        "build/rules/library_deprecated.policy",
        "build/rules/passport.policy",
        "build/rules/yt.policy",
        "build/rules/catboost.policy",
        "build/rules/maps/maps.policy",
        "build/rules/taxi.policy",
        "build/rules/yp.policy",
        "build/rules/alice.policy",
        "build/rules/kikimr.policy",
        "build/rules/yadi.policy",
    ]

    # Paths left out of the search. The entries are joined with "|" into the --exclude
    # pattern given to `ya tool cs` and are also compared against result paths as
    # prefixes in find_proto_libraries.
    EXCLUDED_DIRS = [
        "junk/",
        "devtools/dummy_arcadia/",
        "contrib/",
        "/internal/",
        "market/bootcamp",

        # Broken ya.make files
        "infra/nanny/nanny_internal/proto",
        "robot/zora/tools/source_converter/protos",
        "search/begemot/rules/yabs_caesar_models/caesar_models/config",
        "search/begemot/rules/yabs_caesar_models/caesar_models_setup/config",
        "search/begemot/rules/yabs_caesar_models/caesar_multiks/config",
        "search/begemot/rules/yabs_caesar_models/caesar_multiks_merger/config",
        "search/begemot/rules/yabs_caesar_models/caesar_multiks_setup/config",
        "search/begemot/rules/yabs_caesar_models/caesar_phrase_models/config",
        "search/begemot/rules/yabs_page_models/config",
        "search/itditp/tools/dssm/real_time_training/dssm_trainer/dict/config",
        "taxi/uservices/gen/schemas/proto/envoy",
        "extsearch/geo/fast_export/protos/patch",

        # Broken for an unknown reason (probably the tasklet macro)
        "arc/robots/outstaffsync/proto",
        "disk/tasklets/conductor/proto",
        "disk/tasklets/mpfs_release/proto/",
        "disk/tasklets/startrek/proto",
        "disk/tasklets/teamcity/proto",

        # Broken by an unreachable peerdir
        "infra/nanny/nanny_repo/proto",
        "search/proto/binlog",
        "taxi/uservices/gen/schemas/proto/market",

        # Broken at compile time for various reasons
        "yql/library/proto_ast",
        "kinopoisk/apps/templates",
    ]

    def __init__(self) -> None:
        """Configure the root logger to emit INFO-and-above records to stdout."""
        logging.basicConfig(stream=stdout, level=logging.INFO)

    def run(self):
        """Write the per-revision size report to OUT_FILE.

        Collects the proto projects and their descriptor sizes, loads the SVN log
        entries for them, and then emits one CSV line per log entry:
        ``revision,iso-date,total-size,project1:project2:...`` where the total is the
        sum of the sizes of the projects recorded for that revision.
        """
        projects = self.get_proto_projects_list()
        pr_size = self.get_projects_size_bytes(projects)
        logs, proto_projects_by_revision = self.load_logs()
        logging.info("All info loaded")
        logging.info(f"Sum of sizes = {sum(pr_size.values())}")

        with open(self.OUT_FILE, "w") as out_file:
            for entry in logs:
                revision = entry.revision
                date = entry.date.isoformat()
                affected_projects = proto_projects_by_revision[revision]
                # Total descriptor size of every project touched by this revision.
                size_sum = sum(pr_size[name] for name in affected_projects)
                out_file.write(f"{revision},{date},{size_sum},{':'.join(affected_projects)}\n")

        logging.info("File was filled")

    def prepare_registry_ya_make(self, proto_projects: List[str]) -> None:
        """Generate the registry ya.make for the given projects.

        Delegates directly to build_registry_ya_make.
        """
        self.build_registry_ya_make(proto_projects)

    def load_logs(self) -> Tuple[List[SvnLogEntry], Dict[int, List[str]]]:
        """Load the SVN logs of all proto projects.

        Returns:
            A pair ``(entries, projects_by_revision)``: the de-duplicated log entries
            ordered by ascending revision, and a mapping from revision number to the
            projects whose log contains that revision.
        """
        self.SVN_LOG_CACHE_FOLDER.mkdir(parents=True, exist_ok=True)

        projects = self.find_proto_libraries(self.load_denied_directories())

        unique_entries = set()
        projects_by_revision: Dict[int, List[str]] = {}
        for project in projects:
            project_entries = self.load_svn_log(project)
            # Entries hash/compare by revision, so the set keeps one per revision.
            unique_entries.update(project_entries)
            for entry in project_entries:
                projects_by_revision.setdefault(entry.revision, []).append(project)

        logs_with_proto_changes = sorted(unique_entries, key=lambda item: item.revision)

        logging.info(f"Found {len(logs_with_proto_changes)} logs with proto changes")
        return logs_with_proto_changes, projects_by_revision

    def get_proto_projects_list(self) -> List[str]:
        """Return the proto library projects, excluding the policy-denied directories."""
        return self.find_proto_libraries(self.load_denied_directories())

    def get_projects_size_bytes(self, projects: List[str], cache=True) -> Dict[str, int]:
        """Return the descriptor size in bytes for every project in ``projects``.

        Sizes are persisted in PROJECT_SIZE_CACHE_FILE as ``<project>,<size>`` lines.
        With ``cache`` enabled, previously stored sizes are loaded first and only the
        missing projects are measured. Every measured size is appended to the file
        in both modes.

        Args:
            projects: Arcadia-relative project directories.
            cache: when false, the cache file is not consulted and every project is
                measured again.

        Returns:
            Mapping from project to size. When the cache is used, it also contains
            cached projects that are not listed in ``projects``.

        Raises:
            ValueError: if a non-empty cache line has no comma or a non-integer size.
        """
        projects_size: Dict[str, int] = {}

        if cache and self.PROJECT_SIZE_CACHE_FILE.exists():
            for line in self.PROJECT_SIZE_CACHE_FILE.read_text(UTF_8).split("\n"):
                if not line.strip():
                    continue
                # The size is always the last field; splitting once from the right
                # keeps project paths that themselves contain a comma intact.
                project, size = line.rsplit(",", maxsplit=1)
                projects_size[project] = int(size)

        # Write with the same encoding the cache is read with.
        with open(self.PROJECT_SIZE_CACHE_FILE, "a", encoding=UTF_8) as cache_file:
            for project in projects:
                if cache and project in projects_size:
                    continue
                project_size_bytes = self.get_project_size(project)
                projects_size[project] = project_size_bytes
                cache_file.write(f"{project},{project_size_bytes}\n")
                # Measuring a project may trigger a full build, so persist each
                # result immediately rather than losing it if the run is interrupted.
                cache_file.flush()
        return projects_size

    def find_proto_libraries(self, denied_directories: List[str]) -> List[str]:
        """Find the directories of ya.make files that declare a PROTO_LIBRARY.

        Runs ``ya tool cs`` from ARCADIA_ROOT to list ya.make files containing the
        string ``PROTO_LIBRARY``, excluding EXCLUDED_DIRS and ``denied_directories``,
        and then filters the reported paths locally.

        Args:
            denied_directories: additional directories to exclude from the search.

        Returns:
            Arcadia-relative project directories (the ``/ya.make`` suffix removed) whose
            ya.make exists and passes descriptor_supported_in_file.

        Raises:
            subprocess.CalledProcessError: if the search command exits with an error.
        """
        excluded = self.EXCLUDED_DIRS + denied_directories
        exclude_str = "(" + "|".join(excluded) + ")"
        command = [
            str(self.YA_BINARY),
            "tool",
            "cs",
            "--file", "^.+ya\\.make$",
            "--exclude", exclude_str,
            "--files-with-matches",
            "--no-totals",
            "--max", "0",
            "--fixed-strings",
            "PROTO_LIBRARY"
        ]
        logging.info(f"Run command {' '.join(command)}")
        result = run(
            command,
            check=True,
            cwd=self.ARCADIA_ROOT,
            stdout=PIPE,
        )
        files = []
        for r in result.stdout.decode(UTF_8).split("\n"):
            r = r.strip()
            if not r:
                continue
            if not r.endswith("/ya.make"):
                logging.warning(f"File {r} is note ya.make and will be skipped")
                continue
            # Second line of defence behind --exclude. The match is computed up front so
            # that `continue` applies to the result loop and the file is really skipped.
            # Empty prefixes are ignored because they would match every path.
            d = next((prefix for prefix in excluded if prefix and r.startswith(prefix)), None)
            if d is not None:
                logging.warning(f"File {r} is in denied directory {d} and will be skipped")
                continue
            file = self.ARCADIA_ROOT / r
            if not file.exists():
                logging.warning(f"File {r} not exists")
                continue
            if not self.descriptor_supported_in_file(file):
                continue
            files.append(r.removesuffix("/ya.make"))
        return files

    def build_registry_ya_make(self, files: List[str]) -> None:
        """Write a PROTO_REGISTRY ya.make that PEERDIRs every path in ``files``.

        The registry folder is created if needed and TEST_REGISTRY_YA_MAKE is
        overwritten.
        """
        self.TEST_REGISTRY_FOLDER.mkdir(parents=True, exist_ok=True)

        header = ["PROTO_REGISTRY()\n", "PEERDIR(\n"]
        footer = [")\n", "END()\n"]
        with open(self.TEST_REGISTRY_YA_MAKE, "w") as ya_make_file:
            ya_make_file.writelines(header)
            # One indented PEERDIR argument per project.
            ya_make_file.writelines(f"    {file_path}\n" for file_path in files)
            ya_make_file.writelines(footer)

        logging.info(f"File {self.TEST_REGISTRY_YA_MAKE} successfully created")

    def load_denied_directories(self) -> List[str]:
        """Collect the directories that the policy files deny for every source.

        Each file in RULES_FILES (relative to ARCADIA_ROOT) is scanned for lines of
        the form ``DENY <source> -> <destination>``. The destination is collected
        when the source is exactly ``.*``. DENY lines without exactly one ``->``
        are reported with a warning and ignored.

        Returns:
            The collected destinations in the order they were encountered.
        """
        denied_directories = []
        for policy_file in self.RULES_FILES:
            policy_file = self.ARCADIA_ROOT / policy_file

            logging.debug(f"Parsing {policy_file}")

            for line in policy_file.read_text(UTF_8).split("\n"):
                line = line.strip()
                if not line.startswith("DENY "):
                    continue

                # partition never raises, unlike unpacking the result of split().
                source, arrow, destination = line.removeprefix("DENY ").partition("->")
                if not arrow or "->" in destination:
                    logging.warning(f"Malformed DENY rule in {policy_file}: {line}")
                    continue
                source = source.strip()
                destination = destination.strip()

                logging.debug(f"Parsed {line} as {source} to {destination}")

                if source == ".*":
                    denied_directories.append(destination)

        for d in sorted(denied_directories):
            logging.debug(f"Denied directory {d}")

        return denied_directories

    def descriptor_supported_in_file(self, file: Path):
        """Tell whether the given ya.make should be included in the analysis.

        Returns False, after logging a warning, when the file text contains
        ``ONLY_TAGS`` or does not contain ``PROTO_LIBRARY``. Returns True otherwise.
        """
        text = file.read_text(UTF_8)
        if "ONLY_TAGS" in text:
            message = f"Skip file {file} as it do not build descriptor info"
        elif "PROTO_LIBRARY" not in text:
            message = f"Skip file {file} as it is no longer proto library"
        else:
            return True
        logging.warning(message)
        return False

    def get_revision_with_arc(self):
        """Return the revision reported by ``arc info``.

        The value is the remainder of the first output line that starts with
        ``revision: ``, stripped of surrounding whitespace.

        Raises:
            subprocess.CalledProcessError: if ``arc info`` exits with an error.
            Exception: if the output has no revision line.
        """
        marker = "revision: "
        output = run(["arc", "info"], check=True, stdout=PIPE).stdout.decode(UTF_8)
        for line in output.split("\n"):
            if line.startswith(marker):
                return line[len(marker):].strip()
        raise Exception("No revision info found")

    def load_svn_log(self, arcadia_path: str, cache=True) -> List[SvnLogEntry]:
        """Load the SVN log of one project, from the cache or from the server.

        With ``cache`` enabled, ``<SVN_LOG_CACHE_FOLDER>/<arcadia_path>/log.xml`` is
        used when it exists. Otherwise ``svn log --xml`` is executed and, if caching
        is enabled, its output is stored at that path.

        Args:
            arcadia_path: Arcadia-relative path of the project.
            cache: whether to read and write the on-disk cache.

        Returns:
            The parsed log entries in document order; empty when the log has none.

        Raises:
            subprocess.CalledProcessError: if ``svn log`` exits with an error.
        """
        cache_path = self.SVN_LOG_CACHE_FOLDER / arcadia_path / "log.xml"
        if cache and cache_path.exists():
            result_dict = xmltodict.parse(cache_path.read_text(UTF_8))
        else:
            command = [
                "svn",
                "log",
                "--xml",
                "svn+ssh://arcadia.yandex.ru/arc/trunk/arcadia/",
                arcadia_path,
            ]
            completed = run(
                command,
                stdout=PIPE,
                check=True,
            )
            # Parse before caching so unparsable output never reaches the cache.
            result_dict = xmltodict.parse(completed.stdout)
            if cache:
                cache_path.parent.mkdir(parents=True, exist_ok=True)
                cache_path.write_text(completed.stdout.decode(UTF_8), UTF_8)

        # xmltodict maps an empty <log/> element to None, a single <logentry> child to
        # a mapping and repeated children to a list; normalise all three to a list.
        log_node = result_dict["log"]
        container = log_node.get("logentry") if log_node else None
        if container is None:
            logging.debug(f"For project {arcadia_path} no log entries found")
            return []
        if not isinstance(container, list):
            container = [container]

        entries = []
        for entry in container:
            logging.debug(f"For project {arcadia_path} loaded log entry {entry}")
            entries.append(self.convert_dict_to_entry(entry))
        return entries

    @staticmethod
    def convert_dict_to_entry(entry: OrderedDict) -> SvnLogEntry:
        """Convert one parsed ``<logentry>`` mapping into an SvnLogEntry.

        Args:
            entry: mapping with the ``@revision`` attribute and the ``date`` element,
                plus optional ``author`` and ``msg`` elements.

        Returns:
            The entry, with ``author`` and ``message`` set to ``""`` when the
            corresponding element is absent or empty.

        Raises:
            KeyError: if ``@revision`` or ``date`` is missing.
            ValueError: if the revision is not an integer or the date does not
                match the expected timestamp format.
        """
        return SvnLogEntry(
            revision=int(entry["@revision"]),
            # An empty XML element is parsed as None and the element may be absent
            # altogether; fall back to "" so the fields honour their declared str type.
            author=entry.get("author") or "",
            date=datetime.strptime(entry["date"], "%Y-%m-%dT%H:%M:%S.%f%z"),
            message=entry.get("msg") or "",
        )

    def get_project_size(self, project_path: str) -> int:
        """Return the size in bytes of the project's ``.protodesc`` file.

        An already present descriptor is used as is. Otherwise the project is built
        with ``ya make --replace-result --add-result=.protodesc`` in the project
        directory and the descriptor is looked up again. Returns 0, after logging a
        warning, when the build fails or produces no descriptor.
        """
        prebuilt = self.get_protodesc_file(project_path)
        if prebuilt:
            logging.info(f"Get size from already built project {project_path}")
            return prebuilt.stat().st_size

        command = [str(self.YA_BINARY), "make", "--replace-result", "--add-result=.protodesc"]
        cwd = self.ARCADIA_ROOT / project_path
        logging.info(f"Run {' '.join(command)} at {cwd}")
        build = run(command, cwd=cwd, stdout=PIPE)

        if build.returncode != 0:
            logging.warning(f"Broken project {project_path}")
            return 0

        built = self.get_protodesc_file(project_path)
        if not built:
            logging.warning(f"No file found after build for {project_path}")
            return 0
        return built.stat().st_size

    def get_protodesc_file(self, project_path: str) -> Optional[Path]:
        """Locate the ``.protodesc`` file directly inside the project directory.

        Args:
            project_path: Arcadia-relative path of the project.

        Returns:
            The fully resolved path of the descriptor, or None when the directory
            contains no ``*.protodesc`` entry.

        Raises:
            Exception: if more than one descriptor is present.
        """
        path = self.ARCADIA_ROOT / project_path
        descriptors = list(path.glob("*.protodesc"))
        if not descriptors:
            return None
        if len(descriptors) > 1:
            raise Exception(f"Found multiple descriptors for project {project_path}")

        # resolve() follows symlinks like readlink() did, but it also accepts a regular
        # file (readlink() raises OSError there) and always yields an absolute path,
        # whereas a relative link target would be interpreted against the current
        # working directory by later calls such as stat().
        return descriptors[0].resolve()
