import logging
import json
import urllib2
import time
import argparse
from datetime import datetime

import statface_client
from statface_client import StatfaceClientReportError


def try_request(url, json_data=None, retries=3, wait_seconds=10, timeout=60, exception=False):
    """Fetch ``url`` with ``urllib2`` and return the response body, retrying on failure.

    When ``json_data`` is truthy it is serialized with ``json.dumps`` and sent
    as the request body together with a ``Content-Type: application/json``
    header; otherwise the request carries no body.

    :param url: address to request; ``http://`` is prepended when the value
        starts with neither ``http://`` nor ``https://``.
    :param json_data: optional JSON-serializable payload.
    :param retries: maximum number of attempts.
    :param wait_seconds: pause between two consecutive attempts.
    :param timeout: timeout in seconds passed to ``urlopen`` for every attempt.
    :param exception: when true, raise once every attempt has failed instead
        of returning ``None``.
    :return: the response body, or ``None`` when every attempt failed and
        ``exception`` is false.
    :raises Exception: every attempt failed and ``exception`` is true.
    """
    if not url.startswith(('http://', 'https://')):
        url = 'http://' + url

    for attempt in range(retries):
        if attempt:
            # Wait only between attempts: sleeping after the final failure
            # would just delay reporting the error to the caller.
            time.sleep(wait_seconds)
        try:
            request = urllib2.Request(url)
            if json_data:
                request.add_header('Content-Type', 'application/json')
                response = urllib2.urlopen(request, json.dumps(json_data), timeout=timeout)
            else:
                response = urllib2.urlopen(request, timeout=timeout)
        except urllib2.HTTPError as e:
            logging.error('%s %s', url, e)
            # The body of an HTTP error response usually explains the failure.
            logging.error(e.read())
        except Exception as e:
            # Covers urllib2.URLError as well as any unexpected failure.
            logging.error('%s %s', url, e)
        else:
            try:
                return response.read()
            finally:
                response.close()

    if exception:
        raise Exception('Cannot get data from: {url}'.format(url=url))
    return None


def time_format(timestamp):
    """Render a POSIX timestamp as a ``YYYY-MM-DD HH:MM:SS`` string.

    The conversion goes through ``datetime.fromtimestamp`` without a timezone
    argument, so the result is expressed in the local timezone of the process.
    """
    moment = datetime.fromtimestamp(timestamp)
    rendered = moment.strftime('%Y-%m-%d %H:%M:%S')
    return str(rendered)


class Service(object):
    """A service of a given ctype together with its warning/critical conditions.

    ``info`` is the per-service dictionary: it carries ``service_weight`` and
    one threshold per condition label. ``config`` is the metric configuration:
    its optional ``warnings`` and ``critical`` lists hold patterns of the form
    ``{'label': ..., 'signal': ...}``, where ``signal`` is a format string
    accepting ``ctype`` and ``service``.
    """

    def __init__(self, ctype, name, info, config):
        self.ctype = ctype
        self.name = name
        self.weight = info.get('service_weight')
        # Human-readable identifier, used as the key of the exceedance dictionaries.
        self.id = str(self)
        self.warnings = []
        self.critical = []

        self._add_signals(self.warnings, info, config.get('warnings', []))
        self._add_signals(self.critical, info, config.get('critical', []))

    def __str__(self):
        return '{name} {ctype}'.format(name=self.name, ctype=self.ctype)

    def _add_signals(self, signals, service_info, patterns):
        """Append to ``signals`` a condition for every pattern the service has a threshold for.

        Patterns whose label is absent from ``service_info``, or maps to a
        falsy threshold (``None``, ``0``), are skipped.
        """
        for pattern in patterns:
            label = pattern['label']
            threshold = service_info.get(label)
            if threshold:
                signal = pattern['signal'].format(ctype=self.ctype, service=self.name)
                signals.append({'signal': signal, 'threshold': threshold, 'label': label})

    def _threshold_exceeded(self, conditions, ts, signals_values, status, exceedances_list):
        """Check whether any of ``conditions`` is violated by ``signals_values``.

        :param conditions: condition dictionaries built by ``_add_signals``.
        :param ts: timestamp of the checked point; only used for logging.
        :param signals_values: mapping of signal name to its value at ``ts``.
        :param status: label written to the log ('warning' or 'critical').
        :param exceedances_list: optional dictionary
            ``{service id: {label: {'threshold': ..., 'count': ...}}}`` that is
            updated in place for every violated condition.
        :return: ``True`` when at least one condition is violated.
        """
        result = False
        for condition in conditions:
            value = signals_values.get(condition['signal'])
            if value is None:
                # No data for the signal at this point: nothing to compare.
                # Being explicit avoids relying on the implicit ordering of
                # None against numbers.
                continue
            threshold = condition['threshold']
            if value >= threshold:
                result = True
                label = condition['label']
                logging.debug('{time} {status} {service} : {condition} {value} >= {threshold}'.format(
                    time=time_format(ts),
                    status=status,
                    service=self.id,
                    condition=label,
                    value=value,
                    threshold=threshold
                ))
                if exceedances_list is not None:
                    per_service = exceedances_list.setdefault(self.id, {})
                    stats = per_service.setdefault(label, {'threshold': threshold, 'count': 0})
                    stats['count'] += 1
        return result

    def in_warning(self, ts, signals_values, warnings_list=None):
        """Return ``True`` when any warning condition is violated at ``ts``."""
        return self._threshold_exceeded(self.warnings, ts, signals_values, 'warning', warnings_list)

    def in_critical(self, ts, signals_values, criticals_list=None):
        """Return ``True`` when any critical condition is violated at ``ts``."""
        return self._threshold_exceeded(self.critical, ts, signals_values, 'critical', criticals_list)

    def get_signals(self):
        """Return the distinct signal names used by the service's conditions (order unspecified)."""
        signals = set(condition['signal'] for condition in self.warnings)
        signals.update(condition['signal'] for condition in self.critical)
        return list(signals)


class SignalsApi(object):
    """Reads signal values from golovan through ``yasmapi.GolovanRequest``."""

    # Largest number of signals sent in a single request; bigger lists are split.
    MAX_SIGNALS_COUNT = 100
    # 30 days in seconds. NOTE(review): presumably how long golovan keeps
    # data at the requested period -- confirm.
    DATA_LIFETIME_IN_SEC = 30 * 24 * 60 * 60

    def __init__(self, host='ASEARCH', period=300, attempts=5):
        self.host = host
        self.period = period
        self.attempts = attempts

    def _values(self, start, end, signals):
        """Request ``signals`` for the given range, retrying up to ``self.attempts`` times.

        :return: dictionary mapping each timestamp yielded by the request to its values.
        :raises Exception: every attempt failed.
        """
        # Imported lazily so the module can be loaded without yasmapi installed.
        from yasmapi import GolovanRequest

        for _ in range(self.attempts):
            result = {}
            try:
                # ``end - 1`` is passed as the upper bound. NOTE(review):
                # presumably the bound is inclusive in GolovanRequest -- confirm.
                for ts, values in GolovanRequest(self.host, self.period, start, end - 1, signals, explicit_fail=True):
                    result[ts] = values
                return result
            except Exception as e:
                logging.error('golovan exception: {}'.format(e))
        raise Exception('Cannot get data from golovan')

    def values(self, start, end, signals):
        """Return ``{timestamp: {signal: value}}`` for ``signals`` in the given range.

        ``signals`` may be any iterable. Lists longer than
        ``MAX_SIGNALS_COUNT`` are requested in consecutive chunks and the
        per-timestamp results of the chunks are merged.

        :raises Exception: a request failed on every attempt.
        """
        # Materialize once so generators, sets and tuples work as well as lists.
        signals = list(signals)
        signals_count = len(signals)
        if signals_count <= self.MAX_SIGNALS_COUNT:
            return self._values(start, end, signals)

        result = {}
        for subset_begin in range(0, signals_count, self.MAX_SIGNALS_COUNT):
            signals_subset = signals[subset_begin:subset_begin + self.MAX_SIGNALS_COUNT]
            subset_result = self._values(start, end, signals_subset)
            for ts, subset_values in subset_result.items():
                result.setdefault(ts, {}).update(subset_values)
        return result


class Statistica(object):
    """Small wrapper around ``statface_client`` for reading and writing hourly reports."""

    class ReportConfig(object):
        """Builder of the configuration dictionary uploaded for a newly created report."""

        def __init__(self, path):
            self.impl = {
                'user_config': {
                    'dimensions': [{'fielddate': 'date'}],
                    'measures': [],
                    'view_types': {}
                },
                # The last path component is used as the report title.
                'title': path.split('/').pop(),
            }

        def add_measure(self, name, number_type='Float'):
            """Register a numeric measure ``name`` displayed as ``number_type`` with precision 5."""
            user_config = self.impl['user_config']
            user_config['measures'].append({name: 'number'})
            user_config['view_types'][name] = {
                'type': number_type,
                'precision': 5
            }

    def __init__(self, config):
        """Create the client.

        :param config: dictionary with a mandatory ``token`` (OAuth token) and
            an optional ``cluster``; the value ``'beta'`` selects the beta
            host, anything else the production host.
        """
        host = statface_client.STATFACE_PRODUCTION
        if config.get('cluster', 'prod') == 'beta':
            host = statface_client.STATFACE_BETA
        self.__client = statface_client.StatfaceClient(host=host, oauth_token=config['token'])

    def get_report(self, path, config=None):
        """Return the report at ``path``.

        Without ``config`` the report is fetched with ``get_old_report``; with
        a ``ReportConfig`` it is obtained with ``get_new_report`` and the
        configuration is uploaded with ``overwrite=False``.

        :return: the report object, or ``False`` when the client raised
            ``StatfaceClientReportError`` (the error is logged).
        """
        try:
            if not config:
                return self.__client.get_old_report(path)
            report = self.__client.get_new_report(path)
            report.upload_config(config=config.impl, overwrite=False)
            return report
        except StatfaceClientReportError as e:
            logging.error('StatfaceClientReportError: {}'.format(e))
        return False

    def write(self, report, data):
        """Upload ``data`` to ``report`` on the hourly scale; empty data is not uploaded."""
        if data:
            report.upload_data(scale='hourly', data=data)

    def get_last_timestamp(self, report):
        """Return the timestamp, in seconds, of the first row downloaded from ``report``.

        The value is taken from the row's ``fielddate__ms`` field divided by
        1000. Any failure (including an empty download) is logged and retried,
        up to 5 attempts with a 60 second pause between them.

        :return: the timestamp, or ``0`` when every attempt failed.
        """
        attempts = 5
        for attempt in range(attempts):
            if attempt:
                # Pause only between attempts, not after the final failure.
                time.sleep(60)
            try:
                data = report.download_data(scale='hourly', _period_distance=0)
                return data.pop(0).get('fielddate__ms') // 1000
            except Exception as e:
                logging.error('Cannot get last timestamp: {}'.format(e))
        return 0


class SaaSMetrics(object):
    """Calculates the configured metrics from golovan signals and uploads them as hourly reports.

    The configuration (a dictionary, or a JSON file holding one) is read for
    these keys: ``statistica`` (settings of the report client), ``metrics``
    (metric name -> metric configuration; every name must match a method of
    this class), ``ctypes`` (list of ctypes), and the optional ``range``
    (``from`` / ``to`` timestamps) and ``recalc`` flag.
    """

    # Number of signal points an hourly row is normalized by.
    # NOTE(review): equals 3600 / 300, i.e. it matches the default period of
    # SignalsApi -- keep the two in sync.
    POINTS_PER_HOUR = 12

    def __init__(self, config=None, config_file=None):
        """
        :param config: configuration dictionary.
        :param config_file: path to a JSON file with the configuration; takes
            precedence over ``config`` when both are given.
        :raises Exception: neither ``config`` nor ``config_file`` is specified.
        """
        if not config and not config_file:
            raise Exception('config or config file must be specified')
        if config_file:
            with open(config_file) as config_stream:
                config = json.load(config_stream)
        self.config = config
        self.__signals_api = SignalsApi()
        self.__statistica = Statistica(self.config['statistica'])

    def _get_services(self, metrica_config):
        """Build ``Service`` objects for every service with a non-zero weight.

        The services description (``{ctype: {name: info}}``) is taken from
        ``metrica_config['services']``, then overridden by the JSON file
        ``services_file`` if present, then by the HTTP response of
        ``services_request`` if present. A ``services_request`` containing
        ``{ctype}`` is requested once per configured ctype.

        :raises Exception: no services description could be obtained.
        """
        services_by_ctype = metrica_config.get('services', None)

        if 'services_file' in metrica_config:
            logging.info('Getting services from file')
            with open(metrica_config['services_file']) as services_file_data:
                services_by_ctype = json.load(services_file_data)

        if 'services_request' in metrica_config:
            logging.info('Getting services from request')
            request_template = metrica_config['services_request']
            if '{ctype}' not in request_template:
                result = try_request(request_template, timeout=60 * 10, exception=True)
                services_by_ctype = json.loads(result)
            else:
                services_by_ctype = {}
                for ctype in self.config['ctypes']:
                    request = request_template.format(ctype=ctype)
                    result = try_request(request, timeout=60 * 10, exception=True)
                    services_by_ctype[ctype] = json.loads(result)

        if not services_by_ctype:
            raise Exception('Cannot get services data.')
        logging.info('Received services.')

        services = []
        for ctype, services_data in services_by_ctype.items():
            for name, info in services_data.items():
                if info.get('service_weight'):
                    logging.debug('Add service: {}'.format(name))
                    services.append(Service(ctype, name, info, metrica_config))
                else:
                    logging.debug('Skip service with zero weight: {}'.format(name))
        return services

    def _get_required_range(self, report, estimated_range):
        """Work out which hours have to be calculated.

        The start is the latest of: ``estimated_range['from']``, the last
        timestamp already present in ``report`` (skipped when ``recalc`` is
        set in the configuration), and the oldest moment still covered by the
        signals data lifetime plus one hour; it is then rounded down to a
        whole hour. The end is 25 minutes before now, capped by
        ``estimated_range['to']`` when given.

        :return: tuple ``(start timestamp, number of whole hours)``.
        """
        hour = 60 * 60
        now = int(time.time())

        start = estimated_range.get('from', 0)
        if not self.config.get('recalc'):
            start = max(start, self.__statistica.get_last_timestamp(report))
        start = max(start, now - self.__signals_api.DATA_LIFETIME_IN_SEC + hour)
        start -= start % hour

        end = now - 25 * 60
        if estimated_range.get('to'):
            end = min(estimated_range.get('to'), end)

        hours = max(0, (end - start) // hour)
        logging.info('Required range: from {} to {}. Hours: {}'.format(time_format(start), time_format(end), hours))
        return start, hours

    def get_sla_report_config(self, path):
        """Return the report configuration of the sla metric: ``warning`` and ``critical`` measures."""
        config = Statistica.ReportConfig(path)
        config.add_measure('warning')
        config.add_measure('critical')
        return config

    def get_unanswers_report_config(self, path):
        """Return the report configuration of the unanswers metric: one percent measure per ctype."""
        config = Statistica.ReportConfig(path)
        for ctype in self.config['ctypes']:
            config.add_measure(ctype, 'Percent')
        return config

    def get_report(self, path, config_getter=None):
        """Return the report at ``path``, creating it when it cannot be fetched.

        :param config_getter: optional callable ``path -> ReportConfig`` used
            to create the report when fetching the existing one fails.
        :raises Exception: the report could neither be fetched nor created.
        """
        report = self.__statistica.get_report(path)
        if not report and config_getter:
            report = self.__statistica.get_report(path, config=config_getter(path))
        if not report:
            raise Exception('Cannot get report: %s' % path)
        return report

    def show_sla_results(self, exceedances_list, status, count_all):
        """Log, for every service, how many of ``count_all`` points violated each condition."""
        logging.info('Services in {status} {count}:'.format(status=status, count=len(exceedances_list)))
        for service, conditions in exceedances_list.items():
            logging.info('\t{service}:'.format(service=service))
            for label, info in conditions.items():
                logging.info('\t\t{label} should be less than {threshold}, violated {count} out of {count_all} times'.format(
                    label=label,
                    threshold=info['threshold'],
                    count=info['count'],
                    count_all=count_all
                ))

    def sla(self):
        """Calculate the sla metric and upload it.

        For every hour the ``warning`` / ``critical`` value is the sum of the
        weights of the services violating a condition, accumulated over all
        points of the hour and divided by
        ``total weight * POINTS_PER_HOUR``.

        :raises Exception: the metric is not configured, or there are hours to
            calculate while the total weight of the services is zero.
        """
        logging.info('Calculation of the metric: sla')
        metrica_config = self.config['metrics'].get('sla')
        if not metrica_config:
            raise Exception('No config for metrica: sla')

        report = self.get_report(metrica_config['report'], self.get_sla_report_config)
        start, hours = self._get_required_range(report, self.config.get('range', {}))

        services = self._get_services(metrica_config)
        signals = []
        total_weight = 0.0
        for service in services:
            signals.extend(service.get_signals())
            total_weight += service.weight

        if hours and not total_weight:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise Exception('Total weight of services is zero for metrica: sla')
        normalizer = total_weight * self.POINTS_PER_HOUR

        report_data = []
        warnings_list = {}
        criticals_list = {}
        for hour in range(hours):
            left = start + hour * 60 * 60
            right = left + 60 * 60
            # Every row is dated by the end of the hour it describes.
            row = {'fielddate': time_format(right), 'warning': 0.0, 'critical': 0.0}

            signals_data = self.__signals_api.values(left, right, signals)
            for ts, values in signals_data.items():
                for service in services:
                    if service.in_warning(ts, values, warnings_list):
                        row['warning'] += service.weight
                    if service.in_critical(ts, values, criticals_list):
                        row['critical'] += service.weight

            row['warning'] /= normalizer
            row['critical'] /= normalizer
            report_data.append(row)
            logging.debug(row)

        self.__statistica.write(report, report_data)

        total_points = hours * self.POINTS_PER_HOUR
        self.show_sla_results(warnings_list, 'warning', total_points)
        self.show_sla_results(criticals_list, 'critical', total_points)

    def unanswers(self):
        """Calculate the unanswers metric and upload it.

        For every hour and every ctype the value is the maximum of the ctype's
        signal over the points of the hour (never below 0.0).

        :raises Exception: the metric is not configured.
        """
        logging.info('Calculation of the metric: unanswers')
        metrica_config = self.config['metrics'].get('unanswers')
        if not metrica_config:
            raise Exception('No config for metrica: unanswers')

        ctypes = self.config['ctypes']
        signals = [metrica_config['signal'].format(ctype=ctype) for ctype in ctypes]

        report = self.get_report(metrica_config['report'], self.get_unanswers_report_config)
        start, hours = self._get_required_range(report, self.config.get('range', {}))

        report_data = []
        for hour in range(hours):
            left = start + hour * 60 * 60
            right = left + 60 * 60
            row = {'fielddate': time_format(right)}
            signals_data = self.__signals_api.values(left, right, signals)
            for signals_values in signals_data.values():
                # ``signals`` was built from ``ctypes``, so the two lists are aligned.
                for ctype, signal in zip(ctypes, signals):
                    row[ctype] = max(row.get(ctype, 0.0), signals_values[signal])
            report_data.append(row)
            logging.debug(row)

        self.__statistica.write(report, report_data)

    def calculate(self):
        """Run every metric listed in the ``metrics`` section of the configuration.

        Each metric name is resolved to the method of the same name.

        :raises Exception: a metric name does not correspond to a callable
            attribute of this object.
        """
        for metrica_name in self.config['metrics']:
            metrica = getattr(self, metrica_name, None)
            if not callable(metrica):
                raise Exception('Unknown metrica: {}'.format(metrica_name))
            metrica()


if __name__ == '__main__':
    # Command-line entry point: load the JSON configuration given by
    # -c/--config (default: config.json) and calculate every configured metric.
    cli = argparse.ArgumentParser()
    cli.add_argument('-c', '--config', default='config.json')
    options = cli.parse_args()

    SaaSMetrics(config_file=options.config).calculate()
