#!/usr/bin/env python2.7
from collections import Counter
from datetime import datetime, timedelta
from functools import partial
from multiprocessing import Pool
from os import getenv

import boto3
from dateutil.tz import tzutc
from nile.api.v1.statface.client import StatfaceProductionClient
from nile.api.v1.statface.report import StatfaceReport


# boto3 service name and the S3-compatible endpoint it is pointed at.
SERVICE_NAME = 's3'
ENDPOINT_URL = 'https://s3.mds.yandex.net'

# Bucket that is scanned, and the top-level key prefix whose first-level
# "directories" are treated as shards.
BASE_BUCKET = 'disk-backup-pg'
WALG_PREFIX = 'wal-e'

# Objects last modified more than this many days ago are counted as 'outdated'.
AGE_LIMIT = 90


def count_backups_size(shard, service_name, endpoint, bucket, key_id, secret_key, walg_prefix):
    """Sum the sizes of one shard's objects in S3, grouped by category.

    Every object under ``<walg_prefix>/<shard>/`` is listed page by page and
    its size is added to one or more counters:

    * ``outdated`` -- objects whose LastModified is older than AGE_LIMIT days;
    * the backup type: the 4th key component up to its first '_'
      (a component ``basebackups_005`` would be counted as ``basebackups``);
    * ``full`` / ``delta`` -- non-JSON files directly below a backup type that
      contains 'base': names with more than one '_' count as delta, others as
      full.

    A single object can contribute to several counters.

    :param shard: shard directory name under ``walg_prefix``.
    :param service_name: boto3 service name passed to ``session.client``.
    :param endpoint: S3 endpoint URL.
    :param bucket: bucket to scan.
    :param key_id: S3 access key id.
    :param secret_key: S3 secret access key.
    :param walg_prefix: top-level key prefix that contains the shard directories.
    :return: ``(shard, Counter)`` where the Counter maps category -> bytes.
    """
    print('processing shard {}'.format(shard))
    # A dedicated session per call: this function is run inside Pool workers.
    session = boto3.session.Session(aws_access_key_id=key_id, aws_secret_access_key=secret_key)
    s3 = session.client(service_name=service_name, endpoint_url=endpoint)
    paginator = s3.get_paginator('list_objects')
    # Trailing '/' so that shard 'db1' does not also match 'db10', 'db11', ...
    shard_path = '{}/{}/'.format(walg_prefix, shard)
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=shard_path, MaxKeys=10000)
    shard_backup_size = Counter()
    ts_min = datetime.now(tzutc()) - timedelta(days=AGE_LIMIT)

    for page in page_iterator:
        # 'Contents' is absent from a page that lists no objects.
        for item in page.get('Contents', []):
            path = item['Key'].split('/')
            size = item['Size']
            # Calculating old files size (independent of the key layout)
            if item['LastModified'] < ts_min:
                shard_backup_size['outdated'] += size
            if len(path) < 4:
                # No backup-type component to classify by; skip instead of
                # raising IndexError and aborting the whole pool run.
                continue
            backup_type = path[3].split('_')[0]
            # Calculating delta and full backups size
            if 'base' in backup_type and len(path) > 4 and not path[4].endswith('.json'):
                base_subtype = 'delta' if path[4].count('_') > 1 else 'full'
                shard_backup_size[base_subtype] += size
            # Calculating basebackups and wals size
            shard_backup_size[backup_type] += size

    return (shard, shard_backup_size)


def prepare_stat_data(fielddate, data):
    """Lazily turn per-shard size counters into report rows.

    :param fielddate: value put into the ``fielddate`` field of every row.
    :param data: iterable of ``(shard, sizes)`` pairs, where ``sizes`` is a
        mapping of category name to byte count.
    :return: generator of dicts holding ``fielddate``, ``shard`` and every
        entry of ``sizes``.
    """
    for shard_name, sizes in data:
        row = dict(fielddate=fielddate, shard=shard_name, **sizes)
        yield row


def publish_stat(data, username, password):
    """Publish rows to the Statface report 'Disk/Admin/PgS3Backups' (daily scale).

    :param data: row dicts to upload, e.g. the output of ``prepare_stat_data``.
    :param username: Statface user name.
    :param password: Statface password.
    """
    client = StatfaceProductionClient(username=username, password=password)
    report = StatfaceReport().path('Disk/Admin/PgS3Backups').scale('daily')
    report = report.client(client).data(data)
    report.publish()
    # Function-call form like the rest of the file: a single parenthesized
    # argument prints the same under Python 2 and also parses under Python 3.
    print('Report published')


def print_stat(data):
    """Print a per-shard table of backup sizes in GiB.

    :param data: iterable of row dicts such as those built by
        ``prepare_stat_data``: a ``shard`` key plus byte counts under
        ``basebackups``, ``full``, ``delta``, ``wal`` and ``outdated``.
        Missing keys count as 0.

    The ``diff`` column shows delta base backups; ``total`` is WAL plus
    base backups.
    """
    gib = 1024 ** 3
    print('Size in GB')
    print('{:<26s} {:<14s} {:<14s} {:<14s} {:<14s} {:<14s} {:<14s}'.format('shard', 'basebackups', 'full', 'diff', 'wals', 'outdated', 'total'))
    for each in data:
        shard = each['shard']
        wals_size = float(each.get('wal', 0)) / gib
        basebackups_size = float(each.get('basebackups', 0)) / gib
        full_basebackups_size = float(each.get('full', 0)) / gib
        # count_backups_size stores delta backups under 'delta' (never 'diff');
        # 'diff' is kept only as a fallback for rows built some other way.
        diff_basebackups_size = float(each.get('delta', each.get('diff', 0))) / gib
        total_size = wals_size + basebackups_size
        outdated_files_size = float(each.get('outdated', 0)) / gib
        print('{:<26} {:<14.2f} {:<14.2f} {:<14.2f} {:<14.2f} {:<14.2f} {:<14.2f}'.format(shard, basebackups_size, full_basebackups_size, diff_basebackups_size, wals_size, outdated_files_size, total_size))


if __name__ == '__main__':
    # Credentials come from the environment; unset variables become None.
    s3_access_key_id = getenv('S3_ACCESS_KEY_ID')
    s3_secret_access_key = getenv('S3_SECRET_ACCESS_KEY')
    stat_username = getenv('STAT_USER')
    stat_password = getenv('STAT_PASSWORD')
    # NOTE(review): local date, while object ages are compared in UTC -- confirm intended.
    stat_date = str(datetime.now().date())

    # Fetching list of shards: first-level "directories" under WALG_PREFIX
    print('Fetching list of shards')
    session = boto3.session.Session(aws_access_key_id=s3_access_key_id, aws_secret_access_key=s3_secret_access_key)
    s3 = session.client(service_name=SERVICE_NAME, endpoint_url=ENDPOINT_URL)
    paginator = s3.get_paginator('list_objects')
    page_iterator = paginator.paginate(Bucket=BASE_BUCKET, Prefix=WALG_PREFIX + '/', MaxKeys=10000, Delimiter='/')
    shards = []
    for page in page_iterator:
        # 'CommonPrefixes' is absent from a page without sub-prefixes.
        for item in page.get('CommonPrefixes', []):
            # '<WALG_PREFIX>/<shard>/' -> '<shard>'
            shards.append(item['Prefix'].split('/')[1])

    # Calculating info about backups, one shard per pool task
    print('Calculating info about backups')
    worker = partial(count_backups_size, service_name=SERVICE_NAME, endpoint=ENDPOINT_URL, bucket=BASE_BUCKET,
                     key_id=s3_access_key_id, secret_key=s3_secret_access_key, walg_prefix=WALG_PREFIX)
    pool = Pool(24)
    try:
        results = pool.map(worker, shards)
        pool.close()
    except BaseException:
        # Do not leave worker processes behind if a shard fails or on Ctrl-C.
        pool.terminate()
        raise
    finally:
        pool.join()

    # Preparing and uploading results
    print('Prepairing and uploading results')
    data = prepare_stat_data(stat_date, results)
    publish_stat(data, stat_username, stat_password)
