# -*- coding: utf-8 -*-

import argparse
import itertools
import os
import parsedatetime as pdt
import requests
import sys
import time
import yaml

from datetime import datetime
from pathlib import Path
from retry import retry

import logging
# Configure logging at import time; each record shows timestamp, level,
# logger name and message.
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s %(name)s %(message)s")
logger = logging.getLogger(__name__)

# File name of the downtime config that gets parsed.
YAML_NAME = 'dt.yaml'
# Base directory, taken from the ARCADIA_PATH environment variable; the
# config location below is built relative to it.
ARCADIA_PATH = os.getenv('ARCADIA_PATH')
# NOTE(review): these asserts run at import time and are skipped under
# `python -O`, in which case a missing variable is not reported here.
assert ARCADIA_PATH
# OAuth token sent in the Authorization header of the downtime request.
JUGGLER_TOKEN = os.getenv('JUGGLER_OAUTH_TOKEN')
assert JUGGLER_TOKEN

# Full path of the YAML config under ARCADIA_PATH.
YAML_PATH = Path(ARCADIA_PATH) / 'disk' / 'admin' / 'utils' / 'downtimer' / YAML_NAME

class Downtime:
    '''
    A single downtime filter described by namespace, host and service.

    Instances are mutable and compare by value over the attributes listed
    in ``keys``; because of that they are deliberately unhashable.
    '''
    # Attribute names that form a filter, in constructor-argument order.
    keys = 'namespace', 'host', 'service'
    # Dict key under which nested sub-filters are listed in the config.
    groups = 'groups'

    # Defining __eq__ already disables hashing; spelled out to make it explicit.
    __hash__ = None

    def __init__(self, namespace='', host='', service=''):
        self.namespace = namespace
        self.host = host
        self.service = service

    def overload(self, other):
        '''
        Copy every non-empty attribute of ``other`` onto this object.

        :param other: object providing the attributes named in ``keys``
        :return: self, so the call can be chained
        '''
        for k in self.keys:
            att = getattr(other, k)
            if att:
                setattr(self, k, att)
        return self

    def __eq__(self, other):
        # Let Python fall back to its default comparison for objects that do
        # not look like a Downtime, instead of failing with AttributeError.
        if not all(hasattr(other, k) for k in self.keys):
            return NotImplemented
        return all(getattr(self, k) == getattr(other, k) for k in self.keys)

    def _describe(self):
        # Shared "key: value" rendering used by __str__ and __repr__.
        return ', '.join('%s: %s' % (k, getattr(self, k)) for k in self.keys)

    def __str__(self):
        return self._describe()

    def __repr__(self):
        return '(%s)' % self._describe()

    def format(self):
        '''
        Build the dict form of this filter, dropping empty attributes.

        :return: dict holding only the non-empty attributes
        :raises AssertionError: if the namespace is empty or no attribute
            besides the namespace is set
        '''
        dt = {k: getattr(self, k) for k in self.keys}
        dt = {k: v for k, v in dt.items() if v}
        # Raised explicitly (rather than via `assert`) so the validation is
        # still enforced under `python -O`, while keeping the exception type.
        if not dt.get('namespace'):
            raise AssertionError('downtime filter has no namespace: %r' % self)
        if len(dt) <= 1:
            raise AssertionError('downtime filter needs a host or a service: %r' % self)
        return dt

def parse_dt(dt_dict, parent=None):
    '''
    Recursively expand a config dict into a flat list of Downtime objects.

    Each of the namespace/host/service keys may hold a single value or a
    list; every combination of the listed values becomes one Downtime.
    Non-empty attributes of ``parent`` then override the combined values.
    If the dict has a ``groups`` key, each group is parsed recursively with
    every Downtime of this level as its parent, and only the results of
    those recursive calls are returned.

    :param dt_dict: dict with optional namespace/host/service/groups keys
    :param parent: Downtime whose non-empty attributes are applied to every
        result; a fresh empty Downtime is used when omitted
    :return: list of Downtime objects
    '''
    # Create the default per call rather than once at definition time, so
    # no instance is shared between unrelated invocations.
    if parent is None:
        parent = Downtime()

    # Normalise each key to a list; a missing key contributes one empty value.
    values = []
    for k in Downtime.keys:
        value = dt_dict.get(k, [''])
        values.append(value if isinstance(value, list) else [value])

    current_dts = [Downtime(*combo).overload(parent) for combo in itertools.product(*values)]

    if Downtime.groups not in dt_dict:
        return current_dts

    fin_dts = []
    for dt in current_dts:
        for group in dt_dict[Downtime.groups]:
            fin_dts.extend(parse_dt(group, parent=dt))
    return fin_dts

def parse_all_dts(dc, keys=(), to_json=True):
    '''
    Load the YAML config and expand it into downtime filters for one DC.

    :param dc: value substituted for the ``@dc`` placeholder in hosts and
        services
    :param keys: top-level config sections to use; all sections when empty
    :param to_json: return plain dicts (``Downtime.format()``) when true,
        otherwise the Downtime objects themselves
    :return: list of dicts or of Downtime objects, in config order
    '''
    # The context manager closes the config file even if parsing fails.
    with open(YAML_PATH) as config:
        yml = yaml.safe_load(config)

    if not keys:
        keys = yml.keys()

    downtimes = []
    for k, v in yml.items():
        if k in keys:
            downtimes.extend(parse_dt(v))

    for dt in downtimes:
        # Hosts starting with '%' get a CGROUP prefix, and their '@dc'
        # placeholder keeps the '@' in front of the DC name.
        if dt.host.startswith('%'):
            dt.host = 'CGROUP%s' % dt.host.replace('@dc', '@%s' % dc)
        # In services the whole '@dc' placeholder is replaced by the DC name.
        if '@dc' in dt.service:
            dt.service = dt.service.replace('@dc', dc)

    if to_json:
        return [dt.format() for dt in downtimes]
    return downtimes

@retry(tries=3, delay=1, backoff=1)
def apply_one_dt(downtimes: list, end_time, timeout=30):
    '''
    Send one set_downtimes request covering the given filters.

    The whole function is retried by the decorator (3 tries, 1 second
    apart) whenever it raises.

    :param downtimes: list of filter dicts sent as "filters"
    :param end_time: value sent as "end_time"
    :param timeout: seconds to wait for the HTTP request before giving up
    :raises Exception: with the response body when the status is not 200
    '''
    # NOTE(review): the token is sent over plain http; confirm whether an
    # https endpoint is available.
    url = "http://juggler-api.search.yandex.net/v2/downtimes/set_downtimes"
    data = {
        "end_time": end_time,
        "filters": downtimes
    }
    headers = {
        "accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": "OAuth %s" % JUGGLER_TOKEN
    }

    # Without a timeout a stalled connection would block forever and the
    # retry decorator would never get a chance to run.
    req = requests.post(url, json=data, headers=headers, timeout=timeout)
    if req.status_code != 200:
        raise Exception(req.text)

    # NOTE(review): an exception raised here (e.g. a response without
    # 'downtime_id') also triggers a retry, which repeats the POST.
    logger.info('Downtime %s applied', req.json()['downtime_id'])

def chunks(lst, n):
    '''
    Lazily yield consecutive slices of ``lst`` holding at most ``n`` items.

    :param lst: sliceable sequence with a length
    :param n: slice size; the last slice may be shorter
    :return: generator of slices
    '''
    offsets = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in offsets)

def apply_dts(downtimes, end_time, chunk_size=50):
    '''
    Apply the filters in batches, one request per batch.

    :param downtimes: list of filter dicts
    :param end_time: end time passed on with every batch
    :param chunk_size: maximum number of filters per request
    '''
    batches = chunks(downtimes, chunk_size)
    for batch in batches:
        apply_one_dt(batch, end_time)

def parse_time(dt):
    '''
    Convert a human-friendly time expression into a Unix timestamp.

    :param dt: time expression understood by parsedatetime
    :return: timestamp as returned by ``time.mktime``
    :raises ValueError: if the expression is not recognised
    '''
    cal = pdt.Calendar()
    parsed, status = cal.parse(dt)
    # A zero status means nothing was recognised and the current time was
    # returned; using that would make the downtime end immediately.
    if not status:
        raise ValueError('Cannot parse time expression: %r' % dt)
    return time.mktime(parsed)


def main():
    '''
    Command-line entry point.

    Parses the arguments, builds the filters for the chosen DC and either
    prints them together with the parsed end time (``--test``) or applies
    them.
    '''
    cli = argparse.ArgumentParser(description='Downtime disk checks during DC maintenance')
    cli.add_argument('dc', type=str, choices=['sas', 'iva', 'myt', 'vla', 'man'], help='Downtimed DC')
    cli.add_argument('time', type=str, help='DT duration or endtime (human friendly)')
    cli.add_argument('-t', '--test', action='store_true', required=False, help='Test current dt config')
    options = cli.parse_args()

    filters = parse_all_dts(options.dc)

    if not options.test:
        apply_dts(filters, parse_time(options.time))
        return

    # Dry run: show what would be sent and until when.
    for entry in filters:
        print(entry)
    until = datetime.fromtimestamp(parse_time(options.time))
    print('Downtime until: %s' % until)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


