# coding: utf-8
import json
import logging as log
import math
import os
import re
from collections import defaultdict
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, NamedTuple, Optional, Tuple

import boto3
import botocore
import requests

import sandbox.sdk2 as sdk2
import sandbox.projects.common.binary_task as binary_task

# S3 endpoint used when the task does not override it.
DEFAULT_ENDPOINT = "https://s3.mds.yandex.net"
# Juggler push API; BackupMonitoring.send_events_to_juggler posts events here.
JUGGLER_PUSH_API = "http://juggler-push.search.yandex.net/events"
# BackupMonitoring reports CRIT when the newest backup is older than this many hours.
EXPEC_NEXT_BACKUP_AFTER_HOURS = 35


class Policy(NamedTuple):
    """How many backups of each frequency the rotation preserves.

    Field names are the frequency names used by BackupItem
    ("daily", "weekly", "monthly").
    """

    daily: int
    weekly: int
    monthly: int


BINARY_TASK_ATTR_TARGET = "media/admins/MediaRotateS3Backups"


class InvalidBackupPath(Exception):
    """Raised when an S3 path does not look like a backup directory.

    S3Client.list_backups catches it to tell backup directories apart from
    intermediate directories that still have to be descended into.
    """


class MediaRotateS3Backups(sdk2.Task):
    """Rotate media's db s3 backups

    With ``check`` enabled the task only runs BackupMonitoring and pushes the
    resulting events to Juggler; otherwise it runs BackupRotate over ``prefix``.
    """

    BINARY_TASK_ATTR_TARGET = BINARY_TASK_ATTR_TARGET

    class Requirements(sdk2.Task.Requirements):
        cores = 1

    class Parameters(sdk2.Task.Parameters):

        dry_run = sdk2.parameters.Bool("Dry run, don't write/remove objects", required=False, default=True)
        prefix = sdk2.parameters.String(
            "Database backups path (e.g. music/backup/mysql where 'music' is S3 bucket name)", required=True)

        with sdk2.parameters.Group('Monitoring') as monitoring:
            check = sdk2.parameters.Bool("Run only juggler check, don't start rotation", required=False, default=False)
            with check.value[True]:
                # True *skips* the count check: BackupMonitoring reports missing
                # backups only when skip_count_check is False. The label is
                # therefore phrased as a "skip" option, like skip_size_check.
                skip_count_check = sdk2.parameters.Bool(
                    "Skip check of the expected backups count by policy",
                    description="By default only check existence of the last backup",
                    required=False, default=True)
                skip_size_check = sdk2.parameters.Bool("Do not check backups size", required=False, default=False)
                with skip_size_check.value[True]:
                    # Empty list = skip the size check for every db name.
                    skip_size_check_for_names = sdk2.parameters.List(
                        "Skip size checks for all db names or for specified only?", required=False, default=list(),
                    )
                max_size_diff_up_prc = sdk2.parameters.Integer(
                    "Maximum increase in the backup size as a percentage", required=False, default=50)
                max_size_diff_down_prc = sdk2.parameters.Integer(
                    "Maximum decrease in the backup size as a percentage", required=False, default=10)

        with sdk2.parameters.Group('Rotation policy') as policy:
            rotate_from_last_backup = sdk2.parameters.Bool(
                "Rotate from last backup (default from time.now())", required=False, default=False)
            daily_count = sdk2.parameters.Integer("Preserve N daily backups", required=True, default=5)
            weekly_count = sdk2.parameters.Integer("Preserve N weekly backups", required=True, default=1)
            monthly_count = sdk2.parameters.Integer("Preserve N monthly backups", required=True, default=1)
            check_path_len = sdk2.parameters.Bool(
                "Check if path to backup is >= 4 dirs",
                description="Paranoid mode: if path to backup is less than 4, Fail task (path too short) " +
                            "If False, then at least check that last dir is matched by regexp for DATE_RE, " +
                            "but do not check path length.",
                required=False,
                default=True  # Not to break old behaviour
            )
            check_completed_re = sdk2.parameters.Bool(
                "Check if backup completed by regexp",
                required=False,
                default=True  # Not to break old behaviour
            )

        with sdk2.parameters.Group('S3') as s3params:
            endpoint = sdk2.parameters.String("S3 endpoint", required=True, default=DEFAULT_ENDPOINT)
            secret_id = sdk2.parameters.String(
                "S3 auth yav secret-id",
                required=True,
                description="Yav secret id with S3 access/secret keys",
                default="sec-01czneeqg8pt8n41qbaam67yk5"
            )
            s3_access_key_key = sdk2.parameters.String(
                "S3 access_key key in the yav secret", required=True, default="s3_prod_access_key")
            s3_secret_key_key = sdk2.parameters.String(
                "S3 secret_key key in the yav secret", required=True, default="s3_prod_secret_key")

        with sdk2.parameters.Group('Task executor') as task_executor:
            ext_params = binary_task.binary_release_parameters(stable=True)

    def on_execute(self):
        """Run either the Juggler monitoring check or the backup rotation."""

        secret = sdk2.yav.Secret(self.Parameters.secret_id)  # latest version of the secret
        secret_data = secret.data()

        s3client = S3Client(
            path=self.Parameters.prefix,
            access_key=secret_data[self.Parameters.s3_access_key_key],
            secret_key=secret_data[self.Parameters.s3_secret_key_key],
            check_path_len=self.Parameters.check_path_len,
            check_completed_re=self.Parameters.check_completed_re,
            dry_run=self.Parameters.dry_run,
            endpoint=self.Parameters.endpoint
        )

        policy = Policy(**{
            BackupItem.DAILY: self.Parameters.daily_count,
            BackupItem.WEEKLY: self.Parameters.weekly_count,
            BackupItem.MONTHLY: self.Parameters.monthly_count
        })

        if self.Parameters.check:
            self.set_info(f"Run juggler check s3://{self.Parameters.prefix}")
            monitor = BackupMonitoring(
                policy=policy,
                s3client=s3client,
                skip_count_check=self.Parameters.skip_count_check,
                skip_size_check=self.Parameters.skip_size_check,
                skip_names=self.Parameters.skip_size_check_for_names,
                max_size_diff_up=self.Parameters.max_size_diff_up_prc,
                max_size_diff_down=self.Parameters.max_size_diff_down_prc,
            )
            events = monitor.run()
            # Use .value: f"{member}" of a (str, Enum) renders differently
            # depending on the Python version.
            info_msg = '\n'.join(f"{e.service}:{e.status.value}:{e.description}" for e in events)
            self.set_info(f"Juggler events: \n{info_msg}")
            sent = monitor.send_events_to_juggler(events)
            self.set_info(f"Successfully sent {sent} events")
        else:
            self.set_info(f"Run backup rotation s3://{self.Parameters.prefix}")
            rotator = BackupRotate(
                self.Parameters.rotate_from_last_backup,
                policy,
                s3client
            )
            # In dry-run mode this is the number of backups that *would* be removed.
            removed = rotator.run()
            self.set_info(f"Rotator remove {removed} backups")


class BackupItem(object):
    """A single backup directory on S3, parsed from its path.

    The path must end with ``<type>/<name>/<date>/``, e.g.
    ``backup/mongo/music-stable-mongo-main-mongodb-1/20201211/``, where
    ``<date>`` matches DATE_RE (``YYYYMMDD`` or ``YYYY-MM-DD``).
    """

    # Rotation frequencies; also the field names of ``Policy``.
    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"

    # Rotation states. Backups still in NONE after selection are purged.
    NONE = "none"
    NORMAL = "normal"
    TRANSITIONAL = "transitional"

    SUPPORTED_FREQUENCIES = (DAILY, WEEKLY, MONTHLY)

    DATE_RE = re.compile(r"^(?P<year>20\d\d)-?(?P<month>\d\d)-?(?P<day>\d\d)$")

    def __init__(self, path, state=NONE):
        """Parse ``path`` into type, name and timestamp.

        :param path: S3 key prefix of the backup directory.
        :param state: initial rotation state (default NONE).
        :raises InvalidBackupPath: the last path component is not a valid date.
        """
        #                /type / name                            / date
        # path e.g backup/mongo/music-stable-mongo-main-mongodb-1/20201211/
        path_arr = path.strip("/").split("/")

        # The last path element must be a date like "YYYY-?MM-?DD".
        date_string = path_arr.pop(-1)
        date = BackupItem.DATE_RE.match(date_string)
        if not date:
            raise InvalidBackupPath(path)

        self.name = path_arr.pop(-1)
        self.type = path_arr.pop(-1)
        # DATE_RE yields exactly three numeric groups (year, month, day). A
        # date-shaped but impossible value (e.g. 20201340) is not a backup.
        try:
            self.timestamp = datetime(*[int(x, 10) for x in date.groups(0)])
        except ValueError as err:
            raise InvalidBackupPath(path) from err
        self.state = state
        self.path = path

    def my_key(self, freq: str):
        """Period key of this backup for frequency ``freq`` (see ``key``)."""
        return self.key(freq, self.timestamp)

    @staticmethod
    def key(freq: str, ts: datetime):
        """Get the period key for frequency name ``freq`` and timestamp ``ts``.

        Daily keys are ``YYYY-MM-DD``, weekly keys ``<ISO year> week <ISO week>``,
        monthly keys ``YYYY month MM``. Unknown frequencies fall back to daily.
        """
        if freq == BackupItem.WEEKLY:
            # Pair the ISO week number with the ISO year, not the calendar
            # year. Otherwise days at a year boundary either split one week
            # into two keys (2021-01-01 is ISO 2020-W53) or collide with the
            # first week of the same calendar year (2024-12-30 is ISO 2025-W01).
            iso_year, week, _ = ts.isocalendar()
            return f"{iso_year} week {week}"
        if freq == BackupItem.MONTHLY:
            return f"{ts.year} month {ts.month:02d}"
        return f"{ts.year}-{ts.month:02d}-{ts.day:02d}"

    def __repr__(self):
        return f'<{self.__class__.__name__}(path={self.path}, state={self.state})>'

    def __str__(self):
        return self.path


class S3Client(object):
    """Access to the backups stored in one S3 bucket.

    ``path`` is ``<bucket>/<prefix>``: the part before the first ``/`` selects
    the bucket, the rest is the key prefix walked by ``list_backups``.
    ``rm`` honours ``dry_run`` (only logs, deletes nothing).
    """

    # backup/mongo/music-...-1/20201101/mongo-backup01i.../shard.dump.2020.11.01.tar.zst
    MONGO_BACKUP_COMPLETED_RE = re.compile(r'.*/shard\.dump\.[^/]+\.tar\.(gz|zst)$')
    # backup/mysql/main-mt/2020-11-01/00xxx-...-success.sql.zst
    # Dots are escaped so only a literal ".sql.gz" / ".sql.zst" suffix matches.
    MYSQL_BACKUP_COMPLETED_RE = re.compile(r'.*/\d+[^/]+-success\.sql\.(gz|zst)$')
    # backup/cassandra/music-cass01j.music.yandex.net/20201101/data.tzst
    CASSANDRA_BACKUP_COMPLETED_RE = re.compile(r'.*/data\.(tzst|tgz)$')

    def __init__(self,
                 path: str,
                 access_key: str,
                 secret_key: str,
                 check_path_len: bool = True,
                 check_completed_re: bool = True,
                 dry_run: bool = True,
                 endpoint: str = DEFAULT_ENDPOINT
                 ):
        """
        :param path: ``<bucket>/<prefix>`` where the backups live.
        :param access_key: S3 access key.
        :param secret_key: S3 secret key.
        :param check_path_len: ``rm`` refuses paths with fewer than 4 components.
        :param check_completed_re: ``completed`` falls back to the per-DB regexps
            when no DONE marker exists; if False it treats such backups as completed.
        :param dry_run: ``rm`` only logs instead of deleting.
        :param endpoint: S3 endpoint URL.
        """
        self.check_path_len = check_path_len
        self.check_completed_re = check_completed_re
        self.dry_run = dry_run
        session = boto3.Session(aws_access_key_id=access_key,
                                aws_secret_access_key=secret_key)
        s3 = session.resource("s3", endpoint_url=endpoint)
        bucket, _, prefix = path.partition("/")
        self.bucket = s3.Bucket(bucket)
        self.prefix = prefix

    @staticmethod
    def _dir_prefix(path: str) -> str:
        """Return ``path`` with a single trailing ``/``.

        S3 prefix matching is plain string matching: without the slash,
        ``.../20201101`` would also match ``.../20201101-old/...``.
        """
        return path.rstrip("/") + "/"

    def rm(self, backup: BackupItem):
        """Remove every object under ``backup.path`` (only logged in dry-run mode).

        :raises Exception: the path fails the safety checks below.
        """
        path = backup.path

        # Paranoid safety checks: never delete by a prefix that is too broad.
        if self.check_path_len:
            # Require at least 4 components, e.g. backup/mysql/main-mt/2020-11-01/.
            if len(path.strip("/").split("/")) < 4:
                raise Exception(f"Failed to remove backup {backup}, path too short")
        elif not backup.DATE_RE.findall(os.path.basename(os.path.normpath(path))):
            # Without the length check, at least require that the last
            # component looks like a backup date.
            raise Exception(f"Failed to remove backup {backup}, path too short")

        if self.dry_run:
            log.info(f"Dry run rm(path={path})")
            return

        result = self.bucket.objects.filter(Prefix=self._dir_prefix(path)).delete()
        log.info(f"Delete backup response: {result}")

    def du(self, backup: BackupItem) -> Tuple[int, int]:
        """Return ``(total size in bytes, object count)`` of a backup, excluding MySQL binlogs."""
        size = 0
        objects = 0
        for obj in self.bucket.objects.filter(Prefix=self._dir_prefix(backup.path)):
            if "/binlogs/mysql-bin" in obj.key:
                # Some knowledge of the MySQL backup layout leaks in here: the
                # binlog backup lives in a subdirectory of the main data backup
                # and is appended to during the day, which would break the size
                # comparison, so its objects are skipped. Example skipped keys:
                # backup/mysql/disk-mt/2020-12-01/binlogs/mysql-bin.000062
                # .../2020-12-01/binlogs/mysql-bin.000062-from-22648533-to-23820293
                continue
            objects += 1
            size += obj.size
        return size, objects

    def completed(self, backup: BackupItem) -> bool:
        """Return True if the backup looks finished.

        A ``DONE`` object inside the backup directory is authoritative. Without
        it, the DB-specific *_BACKUP_COMPLETED_RE (chosen by ``/mongo/``,
        ``/mysql/`` or ``/cassandra/`` in the path) must match some key.
        """
        path = backup.path
        dir_prefix = self._dir_prefix(path)
        log.info(f"Check completion for {path}")

        try:
            self.bucket.Object(dir_prefix + "DONE").load()
            return True
        except botocore.exceptions.ClientError as e:
            if e.response['Error']['Code'] != "404":
                raise

        if not self.check_completed_re:
            log.info("Not checking <DB>_BACKUP_COMPLETED_RE, since flag `check_completed_re` is disabled")
            return True
        elif '/mongo/' in path:
            completed_re = self.MONGO_BACKUP_COMPLETED_RE
        elif '/mysql/' in path:
            completed_re = self.MYSQL_BACKUP_COMPLETED_RE
        elif '/cassandra/' in path:
            completed_re = self.CASSANDRA_BACKUP_COMPLETED_RE
        else:
            log.warning(f"Don't know how to check completion for {path}")
            return False

        # Paginate: a single list_objects_v2 call returns at most 1000 keys and
        # the completion marker may be beyond the first page.
        paginator = self.bucket.meta.client.get_paginator("list_objects_v2")
        files: List[str] = []
        for page in paginator.paginate(Bucket=self.bucket.name, Prefix=dir_prefix):
            files.extend(obj["Key"] for obj in page.get("Contents", []))
        if not files:
            log.warning(f"Can't list objects in {path}")
            return False

        is_completed = any(completed_re.match(f) for f in files)
        log.info(f"Found {len(files)} objects in {path}, is completed: {is_completed}")
        return is_completed

    def list_backups(self, prefix: Optional[str] = None) -> List[BackupItem]:
        """Walk ``prefix`` (default: the client prefix) and return every backup under it.

        A directory whose last component parses as a BackupItem date is a
        backup; any other directory is descended into (breadth-first).
        """
        if not prefix:
            prefix = self.prefix
        if prefix and not prefix.endswith("/"):
            prefix += "/"

        paginator = self.bucket.meta.client.get_paginator("list_objects")
        backups: List[BackupItem] = []
        dirs = [prefix]
        while dirs:
            next_dir = dirs.pop(0)
            log.info(f"Descent dir {next_dir}")

            for resp in paginator.paginate(Bucket=self.bucket.name, Prefix=next_dir, Delimiter="/"):
                for obj in resp.get("CommonPrefixes", []):
                    path = obj["Prefix"]
                    try:
                        item = BackupItem(path)
                    except InvalidBackupPath:
                        log.info(f"Found directory: '{path}'")
                        dirs.append(path)
                        continue
                    log.info(f"Found backup: {item}")
                    backups.append(item)
        if not backups:
            log.warning(f"Backups for {prefix} not found")
        return backups

    def purge_stale(self, backups_list: List[BackupItem]) -> int:
        """Remove every backup still in state NONE; return how many were (or would be) removed."""
        removed = 0
        for backup in backups_list:
            if backup.state == BackupItem.NONE:
                log.info(f"Remove stale backup {backup}")
                self.rm(backup)
                removed += 1
            else:
                log.info(f"Preserve backup {backup}")
        return removed


# Nested grouping built by BackupRotate.group_by_frequency:
# entity key -> frequency (daily/weekly/monthly) -> period key (BackupItem.key) -> backups in that period.
BackupGroupsByFrequency = Dict[str, Dict[str, Dict[str, List[BackupItem]]]]


class BackupRotate(object):
    """Select backups to keep according to a rotation ``Policy`` and purge the rest.

    Every backup starts in state ``BackupItem.NONE``; ``select_valid`` promotes
    the ones to keep to ``NORMAL`` or ``TRANSITIONAL``, and
    ``S3Client.purge_stale`` removes everything still in ``NONE``.
    """

    def __init__(self, from_last: bool, policy: Policy, s3client: S3Client):
        """
        :param from_last: count rotation windows from the newest backup instead of now.
        :param policy: how many daily/weekly/monthly backups to keep.
        :param s3client: client used to list and remove backups.
        """
        self.rotate_from_last_backup = from_last
        self.policy = policy
        self.s3cli = s3client

    def run(self) -> int:
        """Rotate the backups; return how many were (or, in dry-run, would be) removed."""
        log.info("Run backups rotation")
        backups_list = self.s3cli.list_backups()

        from_timestamp = self.get_now_timestamp(backups_list)
        log.info(f"Rotate from timestamp ({from_timestamp})")
        backups_groups = self.group_by_frequency(backups_list)
        self.select_valid(backups_groups, from_timestamp, self.policy)

        removed = self.s3cli.purge_stale(backups_list)
        return removed

    def get_now_timestamp(self, backups: List[BackupItem]) -> datetime:
        """Return the reference time the rotation windows are counted from.

        With ``rotate_from_last_backup`` it is the timestamp of the newest backup;
        otherwise, or when there are no backups at all, the current local time.
        """
        if self.rotate_from_last_backup:
            if backups:
                return max(bkp.timestamp for bkp in backups)
            # Nothing to rotate; avoid IndexError on an empty listing.
            log.warning("No backups found to rotate from, using current time")
        return datetime.now()

    def group_by_frequency(self, backups: List[BackupItem]) -> BackupGroupsByFrequency:
        """Group backups by entity, then frequency, then period key.

        The entity key is ``<type>/<name>`` (matching the service name used by
        BackupMonitoring), so same-named databases of different types are
        rotated independently instead of competing for the same slots.
        """
        backup_groups: BackupGroupsByFrequency = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for bkp in backups:
            entity = f"{bkp.type}/{bkp.name}"
            for freq in BackupItem.SUPPORTED_FREQUENCIES:
                backup_groups[entity][freq][bkp.my_key(freq)].append(bkp)
        return backup_groups

    def select_valid(self, backup_groups: BackupGroupsByFrequency, from_ts: datetime, policy: Policy):
        """Mark the backups to preserve by setting their ``state`` in place."""
        # entity is the group key; frequencies maps each frequency to the
        # backups grouped by period.
        for entity, frequencies in backup_groups.items():
            for frequency in BackupItem.SUPPORTED_FREQUENCIES:
                backups = frequencies.get(frequency, None)
                if not backups:
                    log.info(f"No {frequency} backups for {entity}")
                    continue
                log.info(f"Check {frequency} backups for {entity}")
                # "Transitional" periods fill the gap between the last kept
                # daily backup and the first period of a longer frequency.
                # Keeping them leaves the longer periods something to choose
                # from later.
                transition = 0
                # NOTE force cast daily_count to float
                # this is critical for the selection algorithm
                daily_count_as_float = float(policy.daily)
                if frequency == BackupItem.WEEKLY:
                    transition = math.ceil(daily_count_as_float / 7.0)
                if frequency == BackupItem.MONTHLY:
                    transition = math.ceil((daily_count_as_float + policy.weekly * 7) / 30)

                for delta in range(getattr(policy, frequency, 0) + transition):
                    # chkp is the period key (day, ISO week or month) lying
                    # `delta` periods before from_ts. The earliest backup of
                    # that period is marked as kept.
                    # NOTE(review): an older comment here said "newest", but the
                    # code has always picked the earliest; behaviour kept, confirm intent.
                    chkp = self.get_checkpoin(from_ts, frequency, delta)

                    bkp_type = BackupItem.NORMAL
                    if frequency != BackupItem.DAILY and delta < transition:
                        bkp_type = BackupItem.TRANSITIONAL
                    log.info(f"Select {bkp_type} backups in {chkp}(ts={from_ts}) by delta={delta}")
                    backups_list = backups.get(chkp, [])
                    if not backups_list:
                        log.info(f"Not found {bkp_type} backups in selected period")
                        continue
                    bkp = min(backups_list, key=lambda x: x.timestamp)
                    # Never downgrade a backup already kept by a shorter frequency.
                    if bkp.state == BackupItem.NONE:
                        log.info(f"Set state={bkp_type} for backup: {bkp}")
                        bkp.state = bkp_type

    def get_checkpoin(self, timestamp: datetime, freq: str, factor: int) -> str:
        """Return the period key of the period ``factor`` periods before ``timestamp``.

        Daily and weekly steps are 1 and 7 days. Monthly steps go back to the
        last day of the previous month ``factor`` times.
        """
        days = 1 * factor
        if freq == BackupItem.WEEKLY:
            days = 7 * factor
        if freq == BackupItem.MONTHLY:
            # Subtracting tmp.day days from tmp lands on the last day of the
            # previous month; repeat it `factor` times.
            tmp = timestamp
            days = 0
            for _ in range(factor):
                days += tmp.day
                tmp = tmp - timedelta(days=tmp.day)

        chkp_ts = timestamp - timedelta(days=days)
        return BackupItem.key(freq, chkp_ts)


# Identity of a backed-up database (BackupItem.name, BackupItem.type).
# The typename matches the variable name so repr() is accurate and instances
# can be pickled (pickle looks the class up by its __name__ in the module).
BackupEntity = NamedTuple("BackupEntity", [("name", str), ("type", str)])
# (name, type) of a database -> all of its backups; built by BackupMonitoring.group_by_db_name_and_type.
BackupGroupsByEntity = Dict[BackupEntity, List[BackupItem]]
SizeStats = NamedTuple(
    "SizeStats",
    [
        ("path", str),     # S3 key prefix of the backup
        ("size", int),     # total size as returned by S3Client.du
        ("objects", int),  # number of objects counted by S3Client.du
    ],
)


class JugglerStatus(str, Enum):
    """Juggler event status; members are ``str`` instances equal to their value."""

    OK: str = 'OK'
    WARN: str = 'WARN'
    CRIT: str = 'CRIT'

    def __str__(self) -> str:
        # Pin str()/format()/f-string output to the bare value ("OK"). The
        # default rendering of (str, Enum) members differs between Python
        # versions (newer ones produce "JugglerStatus.OK"), and event
        # descriptions are built with f"{status}".
        return str(self.value)


class JugglerEvent(NamedTuple):
    """One monitoring event; serialized with ``_asdict()`` for the Juggler push API."""

    host: str
    service: str
    status: JugglerStatus
    description: str
    tags: List[str]


class BackupMonitoring(object):
    """Check s3 backups and build one Juggler event per (type, name) entity."""

    # Seconds to wait for the Juggler push API before failing.
    JUGGLER_PUSH_TIMEOUT = 60

    def __init__(
        self,
        policy: Policy,
        s3client: S3Client,
        skip_count_check: bool = False,
        skip_size_check: bool = False,
        skip_names: Optional[List[str]] = None,
        max_size_diff_up: int = 50,
        max_size_diff_down: int = 10,
        now: Optional[datetime] = None,
    ):
        """
        :param policy: rotation policy; the sum of its counts is the expected number of backups.
        :param s3client: client used to list and inspect backups.
        :param skip_count_check: do not report missing backups.
        :param skip_size_check: skip the size check for all names, or only for
            ``skip_names`` when that list is non-empty.
        :param skip_names: db names whose size check is skipped (see above).
        :param max_size_diff_up: max growth of the last backup over the average, in percent.
        :param max_size_diff_down: max shrink of the last backup below the average, in percent.
        :param now: reference time for the "last backup too old" check. Defaults
            to the time this object is created (evaluated per call, not once at import).
        """
        self.policy = policy
        self.s3cli = s3client
        self.skip_count_check = skip_count_check
        self.skip_size_check = skip_size_check
        self.skip_names = skip_names if skip_names is not None else []
        self.max_size_diff_up = max_size_diff_up
        self.max_size_diff_down = max_size_diff_down
        self.now = now if now is not None else datetime.now()

    def run(self) -> List[JugglerEvent]:
        """List the backups and return the resulting Juggler events."""
        log.info("Run backup monitoring")
        backups = self.s3cli.list_backups()
        return self.check(backups)

    def group_by_db_name_and_type(self, backups: List[BackupItem]) -> BackupGroupsByEntity:
        """Group backups by (name, type)."""
        groups: BackupGroupsByEntity = defaultdict(list)
        for bkp in backups:
            groups[BackupEntity(name=bkp.name, type=bkp.type)].append(bkp)
        return groups

    def check(self, backups: List[BackupItem]) -> List[JugglerEvent]:
        """Build one event per entity.

        CRIT when, in this order: the newest backup is too old; fewer backups
        than the policy expects (unless skipped); the newest backup is
        unfinished and the previous one is unfinished too. If none apply and
        the newest backup is complete, its size is checked.
        """
        events: List[JugglerEvent] = list()
        groups_by_entity = self.group_by_db_name_and_type(backups)
        tags: List[str] = self.build_juggler_tags(self.s3cli.bucket.name)
        # Loop-invariant: number of backups the policy keeps per entity.
        count_by_policy = sum(self.policy._asdict().values())

        for entity, backups_for_entity in groups_by_entity.items():
            log.info(f"Check counts for {entity}")
            sorted_backups = sorted(backups_for_entity, key=lambda x: x.timestamp, reverse=True)
            last = sorted_backups[0]
            log.info(f"Last backup {last}")
            last_completed = self.s3cli.completed(last)
            prev = sorted_backups[1] if len(sorted_backups) > 1 else None

            missing_count = count_by_policy - len(backups_for_entity)
            status = JugglerStatus.OK
            # .value, not f"{status}": formatting of a (str, Enum) member
            # depends on the Python version.
            desc = status.value
            delta = self.now - last.timestamp
            if delta > timedelta(hours=EXPEC_NEXT_BACKUP_AFTER_HOURS):
                status = JugglerStatus.CRIT
                desc = f"Last backup too old {last}"
            elif not self.skip_count_check and missing_count > 0:
                status = JugglerStatus.CRIT
                desc = f"Missing {missing_count} backups"
            elif not last_completed and prev and not self.s3cli.completed(prev):
                status = JugglerStatus.CRIT
                desc = f"Last backup in progress, but prev({prev}) not completed"

            if status == JugglerStatus.OK and last_completed:
                # if the last backup is complete
                status, desc = self.check_size(last, backups_for_entity)

            e = JugglerEvent(
                host=f"{self.s3cli.bucket.name}-s3-backups",
                service=f"{entity.type}/{entity.name}",
                status=status,
                description=desc,
                tags=tags,
            )
            log.info(f"Add: {e}")
            events.append(e)
        return events

    def check_size(self, last: BackupItem, backups: List[BackupItem]) -> Tuple[JugglerStatus, str]:
        """Compare the size of ``last`` with the average size of ``backups``.

        Only reports CRIT when ``last`` is the smallest or the biggest backup and
        deviates from the average by more than the configured percentages.
        """
        log.info(f"Check size for {last.type}/{last.name}")
        status = JugglerStatus.OK
        desc = status.value

        if len(backups) < 3:
            log.info(f"Skip size check for {last.name} by backups count: {len(backups)}")
            return status, desc

        if self.skip_size_check:
            # An empty skip_names list means "skip for every name".
            if not self.skip_names or last.name in self.skip_names:
                log.info(f"Skip size check for {last.name} by skip_names={self.skip_names}")
                return status, desc

        stats: List[SizeStats] = []
        for bkp in backups:
            size, objects = self.s3cli.du(bkp)
            stats.append(SizeStats(bkp.path, size, objects))
        stats = sorted(stats, key=lambda x: x.size)
        smallest = stats[0]
        biggest = stats[-1]

        if last.path == smallest.path:
            abnormal = smallest
        elif last.path == biggest.path:
            abnormal = biggest
        else:
            return status, desc

        sum_size = sum(s.size for s in stats)
        if sum_size == 0:
            # Every backup measured 0 bytes (du skips MySQL binlogs); there is
            # no average to compare against, and dividing by it would crash.
            log.info(f"Skip size check for {last.name}: all backups have zero size")
            return status, desc
        size_percent = abnormal.size / (sum_size / len(stats) / 100.0)
        size_diff = size_percent - 100
        log.info(f"Last backup {abnormal} size diff {size_diff}%")

        if size_diff < 0 and abs(size_diff) > self.max_size_diff_down:
            status = JugglerStatus.CRIT
            desc = f"Last backup too small {size_diff:0.1f}%  of the avg"
        elif size_diff > self.max_size_diff_up:
            status = JugglerStatus.CRIT
            desc = f"Last backup too large +{size_diff:0.1f}% of the avg"

        return status, desc

    def send_events_to_juggler(self, events: List[JugglerEvent]) -> int:
        """Push ``events`` to the Juggler push API and return how many were accepted.

        :raises Exception: the reply is not the expected JSON, or some event was
            rejected (the rejected per-event statuses are included).
        """
        reply = requests.post(
            JUGGLER_PUSH_API,
            json={
                "source": BINARY_TASK_ATTR_TARGET,
                "events": [e._asdict() for e in events],
            },
            headers={'User-Agent': f"sandbox/projects/{BINARY_TASK_ATTR_TARGET}"},
            timeout=self.JUGGLER_PUSH_TIMEOUT,
        )
        try:
            status = reply.json()["events"]
        except (ValueError, KeyError, TypeError) as err:
            raise Exception(f"Unexpected juggler reply {reply.status_code}: {reply.text}") from err
        errors_event_status = [e for e in status if e["code"] != 200]
        if errors_event_status:
            raise Exception(json.dumps(errors_event_status))
        return len(status)

    @staticmethod
    def build_juggler_tags(bucket) -> List[str]:
        """Tags attached to every event: the bucket and the task name."""
        return [
            "bucket_" + bucket,
            # NOTE(review): "sanbox" is misspelled, but alert configs may match
            # this tag, so it is kept as is.
            "sanbox_task_" + BINARY_TASK_ATTR_TARGET.split("/")[-1],
        ]
