#!/usr/bin/python

import argparse
import os
from datetime import datetime, timedelta

import boto3
import urllib3
from dateutil.tz import tzutc

# The boto3 S3 client is created with certificate verification turned off by
# default, so silence the warning urllib3 would otherwise emit per request.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

# S3 addressing: the boto3 service name, the bucket holding the backups and
# the top-level key prefix under which every shard's data is stored.
SERVICE_NAME = 's3'
BASE_BUCKET = 'disk-backup-pg'
WALG_PREFIX = 'wal-e'

# Version directory name as found in the bucket -> version string written to
# the --list-to-file output.
VERSION_MAP = {
    '906': '9.6',
    '1000': '10',
}

# A shard whose newest base backup is older than this many days is reported.
ALERT_DAYS = 2

class BackupMeta:
    """Read-only view of the backup layout stored in ``BASE_BUCKET``.

    Keys are walked as
    ``<WALG_PREFIX>/<shard>/<version>/basebackups_005/<object>``.
    """

    def __init__(self, endpoint, key_id, key, verify=False):
        """Create an S3 client for *endpoint* authenticated with *key_id*/*key*.

        :param endpoint: S3 endpoint URL.
        :param key_id: access key id.
        :param key: secret access key.
        :param verify: forwarded to boto3's ``client(verify=...)``. The default
            ``False`` keeps the historical behaviour of skipping TLS
            certificate verification; pass ``True`` or a CA bundle path to
            enable it.
        """
        self.session = boto3.session.Session(aws_access_key_id=key_id, aws_secret_access_key=key)
        self.s3_client = self.session.client(service_name=SERVICE_NAME, verify=verify, endpoint_url=endpoint)
        # Backups last modified before this instant count as too old.
        self.ts_min = datetime.now(tzutc()) - timedelta(days=ALERT_DAYS)

    @staticmethod
    def _prefix(*parts):
        """Join *parts* with ``/`` into an S3 "directory" prefix ending in ``/``."""
        # S3 keys always use '/', independent of the local OS path separator.
        return '/'.join(parts + ('',))

    def _list_pages(self, prefix):
        """Return an iterator over every ``list_objects`` page under *prefix*.

        The listing is delimited by ``/``, so each page holds only the objects
        directly under *prefix* (``Contents``) and its immediate
        sub-"directories" (``CommonPrefixes``). Either key may be absent from
        a page when there is nothing of that kind.
        """
        # A single list_objects call returns a bounded number of keys, so
        # always go through the paginator to see the complete listing.
        paginator = self.s3_client.get_paginator('list_objects')
        return paginator.paginate(Bucket=BASE_BUCKET, Prefix=prefix, MaxKeys=10000, Delimiter='/')

    def _common_prefixes(self, prefix):
        """Return the full sub-"directory" prefixes directly under *prefix*."""
        return [item['Prefix']
                for page in self._list_pages(prefix)
                for item in page.get('CommonPrefixes', [])]

    def list_shards(self):
        """Return the shard names found directly under ``WALG_PREFIX``."""
        return [p.split('/')[1] for p in self._common_prefixes(self._prefix(WALG_PREFIX))]

    def pg_versions(self, shard):
        """Return the set of version directory names stored for *shard*."""
        return {p.split('/')[2] for p in self._common_prefixes(self._prefix(WALG_PREFIX, shard))}

    def latest_pg_backup(self, shard):
        """Return ``(key, last_modified)`` of the newest base-backup object of *shard*.

        Objects directly under ``basebackups_005/`` of every version directory
        of *shard* are compared, and the one with the greatest
        ``LastModified`` wins.

        :raises IndexError: if no such object exists for any version.
        """
        latest = None
        # pg_versions() returns an unordered set, so the newest backup must be
        # chosen across all versions rather than taken from the last one seen.
        for version in self.pg_versions(shard):
            prefix = self._prefix(WALG_PREFIX, shard, version, 'basebackups_005')
            for page in self._list_pages(prefix):
                for obj in page.get('Contents', []):
                    # '>=' lets later-listed objects win ties, as the former
                    # stable sort followed by taking the last element did.
                    if latest is None or obj['LastModified'] >= latest[1]:
                        latest = (obj['Key'], obj['LastModified'])
        if latest is None:
            raise IndexError('no base backups found for shard {!r}'.format(shard))
        return latest

    def current_pg_version(self, shard):
        """Return the version directory name that holds *shard*'s newest backup.

        :raises IndexError: if the shard has no backups.
        """
        meta_json_path, _ = self.latest_pg_backup(shard)
        return meta_json_path.split('/')[2]

    def backup_is_too_old(self, shard):
        """Return ``(too_old, shard, ts)`` for *shard*'s newest backup.

        ``too_old`` is True when the timestamp ``ts`` of the newest backup is
        earlier than ``self.ts_min``.

        :raises IndexError: if the shard has no backups.
        """
        _, ts = self.latest_pg_backup(shard)
        result = ts < self.ts_min
        return (result, shard, ts)


def _read_secret(path):
    """Return the contents of the file at *path* with surrounding whitespace stripped."""
    with open(path) as f:
        return f.read().strip()


def _read_blacklist(path):
    """Return the set of shard names listed one per line in the file at *path*.

    An empty/None *path* yields an empty set. Blank lines are skipped.
    """
    if not path:
        return set()
    with open(path) as f:
        return {line.strip() for line in f if line.strip()}


def main():
    """Write a shard/version list and/or a stale-backup alert list.

    ``--list-to-file`` receives ``<shard>,<version>`` lines.
    ``--alerts-to-file`` receives ``<shard>,<YYYY-MM-DD>`` lines for shards
    whose newest backup is older than the alert threshold.

    Shards for which no backup can be found are skipped in both outputs,
    reported on stderr, and make the script exit with status 1 once all
    other shards have been processed.
    """
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--key-id-file', default='/u0/backupcheck/wal-g/envdir/AWS_ACCESS_KEY_ID')
    parser.add_argument('--key-file', default='/u0/backupcheck/wal-g/envdir/AWS_SECRET_ACCESS_KEY')
    parser.add_argument('--s3-endpoint', default='https://s3.mds.yandex.net')
    parser.add_argument('--list-to-file', default=None)
    parser.add_argument('--alerts-to-file', default=None)
    parser.add_argument('--blacklist', default=None)
    args = parser.parse_args()

    backup_metainfo = BackupMeta(args.s3_endpoint,
                                 _read_secret(args.key_id_file),
                                 _read_secret(args.key_file))

    blacklist = _read_blacklist(args.blacklist)
    shards = [shard for shard in backup_metainfo.list_shards() if shard not in blacklist]

    # A shard without any backup used to abort the whole run and leave the
    # output files half written; collect such shards and report them at the end.
    # Both IndexError and KeyError are caught, since either can signal an
    # empty listing depending on the BackupMeta implementation.
    missing = set()

    if args.list_to_file:
        with open(args.list_to_file, 'w') as output_file:
            for shard in shards:
                try:
                    raw_version = backup_metainfo.current_pg_version(shard)
                except (IndexError, KeyError):
                    missing.add(shard)
                    continue
                # Unknown version directories are written as-is instead of
                # crashing on a missing VERSION_MAP entry.
                version = VERSION_MAP.get(raw_version, raw_version)
                output_file.write('{},{}\n'.format(shard, version))

    if args.alerts_to_file:
        with open(args.alerts_to_file, 'w') as output_file:
            for shard in shards:
                try:
                    too_old, shard, ts = backup_metainfo.backup_is_too_old(shard)
                except (IndexError, KeyError):
                    missing.add(shard)
                    continue
                if too_old:
                    output_file.write('{},{}\n'.format(shard, ts.strftime('%Y-%m-%d')))

    for shard in sorted(missing):
        sys.stderr.write('no base backups found for shard {}\n'.format(shard))
    if missing:
        sys.exit(1)


if __name__ == '__main__':
    main()
