# encoding: utf-8

import copy
import glob
import logging
import os.path

import yaml
from jinja2 import Template

from travel.devops.solomon.utils import build_arcadia_path


# Environments pushed to when none are requested explicitly; also the file-name
# suffixes (``<name>.<env>.yaml``) that mark environment-specific overrides.
ENVS = [
    'prod',
    'testing',
]

log = logging.getLogger(name=__name__)


class SolomonAlerts:
    """Push alert definitions kept as YAML files to Solomon.

    Definitions live under ``<root_path>/alerts``. For each alert and
    environment the final definition is assembled from layered YAML files
    (see :meth:`read_alert`), completed by :meth:`fill_basic_data` and then
    created or updated through ``solomon_client``.
    """

    def __init__(self, args, root_path, solomon_client):
        """
        :param args: parsed command-line namespace; its ``command``, ``env``
            and ``alert`` attributes are read.
        :param root_path: directory whose ``alerts`` subdirectory holds the
            YAML alert files.
        :param solomon_client: client exposing ``read_all``, ``read``,
            ``create``, ``update``, ``diff_data`` and ``drop_system_fields``.
        """
        self.args = args
        self.client = solomon_client
        # Ids of alerts that already exist remotely; decides create vs. update.
        self.object_ids = self.get_object_ids()
        self.path = os.path.join(root_path, 'alerts')

    def get_object_ids(self):
        """Return the ids of all alerts reported by ``client.read_all()``.

        An empty list is returned when the client returns nothing.
        """
        data = self.client.read_all()
        if not data:
            log.info("Alerts: there are no resources of this type")
            return []
        return [datum['id'] for datum in data['items']]

    def drop_unwanted_fields(self, alert):
        """Remove fields of a remotely read alert that must not take part in
        the comparison with the local definition. Mutates ``alert`` in place.
        """
        self.client.drop_system_fields(alert)
        alert.pop('state', None)
        # Also drop outdated fields
        alert.pop('notificationChannels', None)
        alert.pop('periodMillis', None)
        alert.pop('delaySeconds', None)

    @staticmethod
    def sort_fields(alert):
        """Sort ``alert['channels']`` by id in place so channel order does not
        affect comparisons. Alerts without ``channels`` are left untouched.
        """
        channels = alert.get('channels')
        if channels is not None:
            alert['channels'] = sorted(channels, key=lambda x: x['id'])

    def write_single_alert(self, alert):
        """Create ``alert`` remotely, or update it if it differs from the remote copy.

        ``alert`` must contain an ``id`` key. The key is popped, and the dict is
        otherwise mutated in place: channels are sorted, and ``version`` is added
        on update.
        """
        alert_id = alert.pop('id')
        self.sort_fields(alert)
        if alert_id in self.object_ids:
            alert_old = self.client.read(alert_id)
            # Read the remote version before drop_unwanted_fields, which may strip
            # it; the update is sent with this version.
            version = alert_old.get('version')
            self.drop_unwanted_fields(alert_old)
            self.sort_fields(alert_old)
            if alert_old == alert:
                log.info(f"Alert {alert_id} not changed")
                return
            self.client.diff_data(alert_old, alert, alert_id)
            log.info(f"*** Alert {alert_id} changed, updating")
            alert['version'] = version
            self.client.update(alert, alert_id)
            log.info(f"Alert {alert_id} successfully updated")
        else:
            log.info(f"Alert {alert_id} not found. Creating new alert")
            self.client.create(alert, alert_id)
            # Any later write of this id in the same run must update, not re-create.
            self.object_ids.append(alert_id)
            log.info(f"Alert {alert_id} successfully created")

    def merge_dicts(self, d1, d2):
        """Recursively merge ``d2`` into ``d1`` in place; ``d2`` wins on conflicts.

        Nested dicts present on both sides are merged key by key. Any other
        value from ``d2`` replaces the value in ``d1``, including a non-dict
        replacing a dict. Stored values are deep-copied, so later merges into
        ``d1`` never mutate ``d2``. This also covers objects shared through
        YAML anchors.
        """
        for k, v2 in d2.items():
            v1 = d1.get(k)
            if isinstance(v1, dict) and isinstance(v2, dict):
                self.merge_dicts(v1, v2)
            else:
                d1[k] = copy.deepcopy(v2)

    def _load_alert_file(self, path):
        """Parse the YAML alert file at ``path`` and normalise its legacy fields.

        An empty file yields an empty alert (``{}``) instead of ``None``.

        :raises Exception: the document is neither empty nor a mapping.
        """
        with open(path, 'rt', encoding='utf-8') as f:
            alert = yaml.load(f, Loader=yaml.FullLoader)
        if alert is None:
            alert = {}
        if not isinstance(alert, dict):
            raise Exception('Alert file %s must contain a mapping' % path)
        self.convert_fields(alert)
        return alert

    def read_basic_alert(self, alert):
        """Load the alert referenced by ``alert['inherited_from']``, if any.

        ``inherited_from`` is removed from ``alert`` so that it is not merged
        into the final definition. Returns ``{}`` when there is no such key.
        """
        if 'inherited_from' not in alert:
            return {}
        basic_alert_path = build_arcadia_path(alert.pop('inherited_from'))
        return self._load_alert_file(basic_alert_path)

    def get_alert_for_env(self, alert_name, env, default=None):
        """Return one layer: the alert file for ``env``, merged over its ``inherited_from`` alert."""
        result = {}
        alert = self.read_alert_for_env(alert_name, env, default=default)
        basic_alert = self.read_basic_alert(alert)
        self.merge_dicts(result, basic_alert)
        self.merge_dicts(result, alert)
        return result

    def read_alert(self, alert_name, env):
        """Assemble the definition of ``alert_name`` for ``env`` from its layers.

        Layers are merged in increasing priority:

        1. ``_base.yaml`` (required) and ``_base.<env>.yaml``;
        2. ``<dir>/_base.yaml`` and ``<dir>/_base.<env>.yaml`` for the alert's
           directory;
        3. ``<alert_name>.yaml`` (required) and ``<alert_name>.<env>.yaml``.
        """
        result = {}
        self.merge_dicts(result, self.get_alert_for_env('_base', ''))
        self.merge_dicts(result, self.get_alert_for_env('_base', env, {}))
        local_base_dir = os.path.dirname(alert_name)
        # For a top-level alert the "local" base is the root _base already
        # applied above; re-applying it would not change the result.
        if local_base_dir:
            local_base = os.path.join(local_base_dir, '_base')
            self.merge_dicts(result, self.get_alert_for_env(local_base, '', {}))
            self.merge_dicts(result, self.get_alert_for_env(local_base, env, {}))
        self.merge_dicts(result, self.get_alert_for_env(alert_name, ''))
        self.merge_dicts(result, self.get_alert_for_env(alert_name, env, {}))
        return result

    def read_alert_for_env(self, alert_name, env, default=None):
        """Load ``<alert_name>[.<env>].yaml`` relative to the alerts directory.

        :param env: environment suffix; an empty string selects the file
            without a suffix.
        :param default: returned when the file does not exist; when ``None``,
            a missing file is an error.
        :raises Exception: the file is missing and no ``default`` was given.
        """
        fn = alert_name
        if env:
            fn += '.' + env
        fn += '.yaml'
        path = os.path.join(self.path, fn)
        if os.path.exists(path):
            return self._load_alert_file(path)
        if default is not None:
            return default
        raise Exception('Alert not found by path %s' % path)

    @staticmethod
    def list_to_string(data, joiner=''):
        """Join a list of strings with ``joiner``; return a string unchanged.

        :raises Exception: ``data`` is neither a list nor a string.
        """
        if isinstance(data, list):
            return joiner.join(data)
        if isinstance(data, str):
            return data
        raise Exception("Unknown data type in var: %s" % data)

    def convert_alert_program(self, alert, variables, alert_type, program_attr, joiner):
        """Render ``alert['type'][alert_type][program_attr]`` as a Jinja template.

        A list value is first joined with ``joiner``. Nothing happens when the
        alert has no ``alert_type`` section.
        """
        if alert['type'].get(alert_type) is None:
            return

        program = alert['type'][alert_type][program_attr]
        program = self.list_to_string(program, joiner)
        program_template = Template(program)
        alert['type'][alert_type][program_attr] = program_template.render(**variables)

    @staticmethod
    def substitute_alert_vars(alert, variables, fields_path):
        """Render the string at the nested key path ``fields_path`` of ``alert`` as a Jinja template."""
        curr = alert
        for path_path in fields_path[:-1]:
            curr = curr[path_path]
        program_template = Template(curr[fields_path[-1]])
        curr[fields_path[-1]] = program_template.render(**variables)

    @staticmethod
    def convert_fields(alert):
        """Translate legacy field names of a local alert to the current format, in place.

        ``periodMillis`` / ``periodSeconds`` / ``periodMinutes`` / ``periodHours``
        become ``windowSecs``; when several are given, the later one in that
        order wins. ``notificationChannels`` become ``channels`` entries of the
        form ``{'id': ..., 'config': ...}``. ``delaySeconds`` becomes ``delaySecs``.
        """
        period_millis = alert.pop('periodMillis', None)
        if period_millis is not None:
            alert['windowSecs'] = int(period_millis / 1000)
        period_sec = alert.pop('periodSeconds', None)
        if period_sec is not None:
            alert['windowSecs'] = period_sec
        period_min = alert.pop('periodMinutes', None)
        if period_min is not None:
            alert['windowSecs'] = period_min * 60
        period_hr = alert.pop('periodHours', None)
        if period_hr is not None:
            alert['windowSecs'] = period_hr * 60 * 60
        # Convert alert format
        if 'notificationChannels' in alert:
            channels = list()
            for channel in alert.pop('notificationChannels'):
                if isinstance(channel, str):
                    channels.append({'id': channel, 'config': {}})
                else:
                    channels.append({
                        'id': channel['id'],
                        'config': channel.get('config', {})
                    })
            alert['channels'] = channels
        if 'delaySeconds' in alert:
            alert['delaySecs'] = alert.pop('delaySeconds')

    def fill_basic_data(self, alert, alert_name, env):
        """Complete an assembled alert with id, name, rendered templates and annotations.

        Helper keys ``env_suffix``, ``_program_vars``, ``_service_add``,
        ``_host`` and ``actionItems`` are consumed and removed. Mutates
        ``alert`` in place.

        :raises Exception: the alert has no description, or it uses
            ``groupByLabels`` without ``_service_add`` / ``_host``.
        """
        if not alert.get('description'):
            raise Exception("Alert %s has no description" % alert_name)
        alert_path = os.path.dirname(alert_name)
        alert_name_plain = alert_name.replace('/', '-')
        full_name = alert_name_plain
        juggler_service_suffix = ''
        # By default each environment gets its own id and juggler service.
        if alert.pop('env_suffix', True):
            full_name = full_name + '_' + env
            juggler_service_suffix = '-' + env
        alert['id'] = full_name
        alert['name'] = full_name
        # A YAML key with no value loads as None; treat it like an absent key.
        variables_orig = alert.pop('_program_vars', None) or {}
        variables = {}
        for k, v in variables_orig.items():
            if isinstance(v, list):
                variables[k] = self.list_to_string(v)
            else:
                variables[k] = v
        variables['ENV'] = env
        self.convert_alert_program(alert, variables, 'expression', 'program', '\n')
        self.convert_alert_program(alert, variables, 'threshold', 'selectors', '')
        self.substitute_alert_vars(alert, variables, ['description'])
        service_add = alert.pop('_service_add', '')
        host = alert.pop('_host', 'cluster')
        if 'groupByLabels' in alert and (service_add == '' and host == 'cluster'):
            raise Exception("When using groupByLabels (multialerts) you should also use _service_add or _host")
        # No layer may define annotations (or it may set them to null).
        annotations = alert.get('annotations') or {}
        alert['annotations'] = annotations
        annotations['env'] = env
        annotations['juggler_service'] = f'{alert_name_plain}{service_add}{juggler_service_suffix}'
        annotations['service'] = f'{alert_path}{service_add}'
        annotations['host'] = host
        action_items = alert.pop('actionItems', None)
        if action_items:
            # A single string is kept as is; '\n'.join would split it into characters.
            annotations['action_items'] = self.list_to_string(action_items, '\n')

    @staticmethod
    def is_env_specific(alert_name):
        """Return True if ``alert_name`` ends with ``.<env>`` for an env in ``ENVS``."""
        return any(alert_name.endswith('.' + env) for env in ENVS)

    def iterate_alert_files(self):
        """Yield paths of all ``*.yaml`` files below the alerts directory, sorted.

        Sorting makes the push order deterministic regardless of file system.
        """
        yield from sorted(glob.iglob(os.path.join(self.path, '**/*.yaml'), recursive=True))

    def _discover_alert_names(self):
        """Yield alert names (path relative to the alerts directory, without ``.yaml``).

        Base layers (``_base.*``), templates (``*.template.yaml``) and
        environment-specific overrides (``*.<env>.yaml``) are skipped.
        """
        for path in self.iterate_alert_files():
            # Use '/' separators so that names map to ids the same way on every OS.
            rel_path = os.path.relpath(path, self.path).replace(os.sep, '/')
            file_name = os.path.basename(rel_path)
            if file_name.startswith('_base.') or file_name.endswith('.template.yaml'):
                continue
            name = rel_path[:-len('.yaml')]
            if not self.is_env_specific(name):
                yield name

    def push(self):
        """Create or update the selected alerts in every selected environment.

        Without ``args.env`` all of ``ENVS`` are used. Without ``args.alert``,
        every alert file found under the alerts directory is pushed.
        """
        if not self.args.env:
            # Copy so that later changes to args.env cannot leak into ENVS.
            self.args.env = list(ENVS)

        if not self.args.alert:
            # Assign instead of appending: args.alert may be None rather than [].
            self.args.alert = list(self._discover_alert_names())

        for alert_name in self.args.alert:
            for env in self.args.env:
                alert = self.read_alert(alert_name, env)
                # An alert pinned to one environment is skipped in the others.
                only_env = alert.pop("only_env", None)
                if only_env is not None and only_env != env:
                    continue
                self.fill_basic_data(alert, alert_name, env)
                self.write_single_alert(alert)

    def run(self):
        """Call the method of this object named by ``args.command``."""
        getattr(self, self.args.command)()
