# -*- coding: utf-8 -*-
"""
Created on Apr 19, 2016

@author: noob
"""

from collections import OrderedDict
from datetime import datetime
from itertools import accumulate

from common.util.clients import ClickhouseClient, escape_string
from monitoring.models import Metric
from common.models import Server

class Aggregator(object):
    """
    Base class for the ClickHouse-backed aggregators.

    Stores the job, tag and time-range filters used by the subclasses and
    offers helpers that turn raw counts into percentages.
    """

    # String layouts accepted for the ``start``/``end`` bounds; the first
    # layout that parses wins.
    datetime_formats = (
        "%Y%m%d%H%M%S",
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%dT%H:%M:%S",
    )

    def __init__(self, job_obj=None, tag='', start=None, end=None):
        """
        :param job_obj: job object; subclasses read its ``basic_query_params``,
            ``multitag`` and ``cases`` attributes
        :param tag: tag value the subclasses filter rows by
        :param start: lower time bound: a number (truncated to int), a string
            in one of ``datetime_formats``, or None
        :param end: upper time bound, same accepted forms as ``start``
        """
        self.job_obj = job_obj
        self.tag = tag
        self.start = self._normalize_bound(start)
        self.end = self._normalize_bound(end)
        self.ch_client = ClickhouseClient()

    def _normalize_bound(self, value):
        """
        Normalize one time-range bound.

        Numbers are truncated to ``int``. Strings matching one of
        ``datetime_formats`` are re-rendered as ``'YYYY-MM-DD HH:MM:SS'``;
        strings matching none of them, and any other value (e.g. None), are
        returned unchanged.
        """
        if isinstance(value, (int, float)):
            return int(value)
        if isinstance(value, str):
            for fmt in self.datetime_formats:
                try:
                    return str(datetime.strptime(str(value), fmt))
                except ValueError:
                    continue
        return value

    @staticmethod
    def count_percentage(count_data):
        """
        Return each count's share of the total, in percent rounded to 3 places.

        :param count_data: list or generator of counts
        :returns: list of floats in the same order as ``count_data``; all
            zeros when the counts sum to zero (instead of ZeroDivisionError)
        """
        count_data = list(count_data)
        summa = float(sum(count_data))
        if not summa:
            return [0.0] * len(count_data)
        return [round(float(value) * 100 / summa, 3) for value in count_data]

    @staticmethod
    def count_quantiles(data):
        """
        Return the cumulative share of the counts up to and including each
        bin, in percent rounded to 3 places (i.e. ascending percentiles).

        :param data: list or generator of counts per bin, with the bins (not
            the counts) sorted in ascending order
        :returns: list of floats; all zeros when the counts sum to zero
        """
        data = list(data)
        summa = sum(data)
        if not summa:
            return [0.0] * len(data)
        # One pass of running totals instead of re-summing every prefix.
        return [round(float(value) * 100 / summa, 3)
                for value in accumulate(data)]


class RTDetailsAggregator(Aggregator):
    """
    Summary statistics per response-time detail parameter for one job, read
    from ``loaddb.rt_microsecond_details_buffer``.
    """

    def aggregate(self):
        """
        Collect statistics for every parameter of ``params_sql_mapping``.

        :returns: OrderedDict keyed by parameter name (in mapping order); each
            value is a dict with 'average', 'stddev', 'minimum', 'maximum'
            and 'median'
        """
        result = OrderedDict()
        for param in self.params_sql_mapping:
            values = self.get_param_data(param)
            result[param] = {
                'average': values[0],
                'stddev': values[1],
                'minimum': values[2],
                'maximum': values[3],
                'median': values[4],
            }
        return result

    def get_param_data(self, param):
        """
        Query average, population stddev, min, max and median of one parameter.

        :param param: key of ``params_sql_mapping``
        :returns: first result row (five values rounded to 3 places), or
            ``[0] * 5`` when the query returns no rows
        """
        query_params = self.job_obj.basic_query_params.copy()
        query_params['param'] = self.params_sql_mapping[param]
        if param in ('resps', 'threads', 'input', 'output'):
            sql = '''
                select round(avg({param}),3),
                round(stddevPop({param}),3),
                round(min({param}),3),
                round(max({param}),3),
                round(median({param}),3)
                from loaddb.rt_microsecond_details_buffer
                where job_id={job}
                and job_date=toDate({job_date})
            '''
        else:
            # Timing parameters are divided by 1000 (presumably microseconds
            # to milliseconds, judging by the table name - confirm). Every
            # aggregate, max() included, is divided after aggregation.
            sql = '''
                select round(avg({param})/1000,3),
                round(stddevPop({param})/1000,3),
                round(min({param})/1000,3),
                round(max({param})/1000,3),
                round(median({param})/1000,3)
                from loaddb.rt_microsecond_details_buffer
                where job_id={job}
                and job_date=toDate({job_date})
            '''
        sql += self._filter_sql(query_params)
        if param not in ('threads', 'resps'):
            # Several mapped expressions divide by resps (see
            # params_sql_mapping), so rows without responses are excluded.
            sql += ' and resps!=0'

        data = self.ch_client.select(sql, query_params)
        return data[0] if data else [0] * 5

    def _filter_sql(self, query_params):
        """
        Return the tag and time-range conditions to append to the WHERE
        clause, storing the values they reference in ``query_params``.

        For multitag jobs with a tag set, every case whose '|'-separated parts
        include the tag is matched. If no case matches, or the job is not
        multitag, the tag column is compared with ``self.tag`` directly.
        """
        cases_with_tag = []
        if self.job_obj.multitag and self.tag:
            cases_with_tag = ["'{}'".format(escape_string(str(case)))
                              for case in self.job_obj.cases
                              if self.tag in case.split('|')]
        if cases_with_tag:
            sql = ' and tag in ({cases_with_tag})'
            query_params['cases_with_tag'] = ','.join(cases_with_tag)
        else:
            sql = " and tag='{tag}'"
            # Escaped like the case names above so that a quote in the tag
            # cannot break out of the SQL string literal.
            query_params['tag'] = escape_string(str(self.tag))
        if self.start:
            sql += " and time >= toDateTime('{start}')"
            query_params['start'] = self.start
        if self.end:
            sql += " and time <= toDateTime('{end}')"
            query_params['end'] = self.end
        return sql

    @property
    def params_sql_mapping(self):
        """
        Ordered mapping of reported parameter name -> SQL expression over the
        details table. The order defines the order of ``aggregate()`` output.
        """
        return OrderedDict((
            ('resps', 'resps'),
            ('threads', 'threads'),
            ('expect', '(connect_time_sum + send_time_sum + latency_sum + receive_time_sum)/resps'),
            ('connect_time', 'connect_time_sum/resps'),
            ('send_time', 'send_time_sum/resps'),
            ('latency', 'latency_sum/resps'),
            ('receive_time', 'receive_time_sum/resps'),
            # 'igress' (sic) is kept verbatim - presumably the actual column
            # name in the table; confirm before changing the spelling.
            ('input', 'igress'),
            ('output', 'egress'),
        ))


class ProtoCodesAggregator(Aggregator):
    """
    Counts and percentages per distinct value of ``param`` in ``sql_table``
    for one job. Subclasses reuse the query by overriding ``sql_table`` and/or
    ``param``.
    """

    def aggregate(self):
        """
        :returns: OrderedDict mapping each ``param`` value (in query order,
            i.e. ascending) to ``{'count': ..., 'percent': ...}``
        """
        data = self.get_raw_data()
        percentage = self.count_percentage(row[1] for row in data)
        result = OrderedDict()
        for row, percent in zip(data, percentage):
            result[row[0]] = {
                'count': row[1],
                'percent': percent,
            }
        return result

    def get_raw_data(self):
        """
        Build and run the counting query.

        :returns: result rows whose first column is the ``param`` value and
            whose second is the summed ``cnt`` as Float32, grouped and ordered
            by ``param``
        """
        query_params = self.job_obj.basic_query_params.copy()
        query_params.update({'param': self.param, 'sql_table': self.sql_table})

        sql = '''select {param}, toFloat32(sum(cnt))
            from {sql_table}
            where job_id={job}
            and job_date=toDate({job_date})
            '''
        sql += self._filter_sql(query_params)
        sql += '''
            group by {param}
            order by {param}'''

        return self.ch_client.select(sql, query_params)

    def _filter_sql(self, query_params):
        """
        Return the tag and time-range conditions to append to the WHERE
        clause, storing the values they reference in ``query_params``.

        For multitag jobs with a tag set, every case whose '|'-separated parts
        include the tag is matched. If no case matches, or the job is not
        multitag, the tag column is compared with ``self.tag`` directly.
        """
        cases_with_tag = []
        if self.job_obj.multitag and self.tag:
            cases_with_tag = ["'{}'".format(escape_string(str(case)))
                              for case in self.job_obj.cases
                              if self.tag in case.split('|')]
        if cases_with_tag:
            sql = ' and tag in ({cases_with_tag})'
            query_params['cases_with_tag'] = ','.join(cases_with_tag)
        else:
            sql = " and tag='{tag}'"
            # Escaped like the case names above so that a quote in the tag
            # cannot break out of the SQL string literal.
            query_params['tag'] = escape_string(str(self.tag))
        if self.start:
            sql += " and time >= toDateTime('{start}')"
            query_params['start'] = self.start
        if self.end:
            sql += " and time <= toDateTime('{end}')"
            query_params['end'] = self.end
        return sql

    @property
    def sql_table(self):
        """ClickHouse table the counts are read from."""
        return 'loaddb.proto_codes_buffer'

    @property
    def param(self):
        """SQL expression the counts are grouped by."""
        return 'code'


class NetCodesAggregator(ProtoCodesAggregator):
    """
    Same counting query as ProtoCodesAggregator, but over the net-codes table.
    ``param`` is inherited unchanged (grouping by ``code``).
    """

    @property
    def sql_table(self):
        """ClickHouse table the counts are read from."""
        return 'loaddb.net_codes_buffer'


class RTHistogramsAggregator(ProtoCodesAggregator):
    """
    Response-time histogram for one job. It reports the count, the percentage
    and the cumulative percentage per bin of
    ``loaddb.rt_microsecond_histograms_buffer``.
    """

    def aggregate(self):
        """
        :returns: OrderedDict mapping each bin (in query order, i.e. ascending)
            to ``{'count': ..., 'percent': ..., 'quantile': ...}``, where
            'quantile' is the cumulative percentage up to and including the bin
        """
        data = self.get_raw_data()
        # Extract the counts once and feed both helpers (each copies its input).
        counts = [row[1] for row in data]
        percentage = self.count_percentage(counts)
        quantiles = self.count_quantiles(counts)

        result = OrderedDict()
        for row, percent, quantile in zip(data, percentage, quantiles):
            result[row[0]] = {
                'count': row[1],
                'percent': percent,
                'quantile': quantile,
            }
        return result

    @property
    def sql_table(self):
        """ClickHouse table the histogram counts are read from."""
        return 'loaddb.rt_microsecond_histograms_buffer'

    @property
    def param(self):
        """Grouping expression: the bin divided by 1000 (ClickHouse '/' yields a float)."""
        return 'bin/1000'


class MonitoringAggregator(Aggregator):
    """
    Summary statistics of monitoring values from
    ``loaddb.monitoring_verbose_data_buffer``, grouped by target host and
    metric name.
    """

    def aggregate(self):
        """
        :returns: nested OrderedDicts
            ``{target_host: {metric_name: {'average', 'stddev', 'minimum',
            'maximum', 'median'}}}`` in query order
        """
        result = OrderedDict()
        for row in self.get_raw_data():
            target = row[0]
            # Dict membership is O(1); the old list(result.keys()) scan made
            # the loop quadratic in the number of rows.
            if target not in result:
                result[target] = OrderedDict()
            result[target][row[1]] = {
                'average': row[2],
                'stddev': row[3],
                'minimum': row[4],
                'maximum': row[5],
                'median': row[6],
            }
        return result

    def get_raw_data(self, target=None, metric=None):
        """
        Query avg, population stddev, min, max and median of monitoring values.

        :param target: optional value of ``Server.n``; restricts the query to
            that server's host
        :param metric: optional ``Metric`` id; restricts the query to that
            metric's code
        :returns: rows of (target_host, metric_name, avg, stddev, min, max,
            median), ordered by target_host and metric_name
        """
        query_params = self.job_obj.basic_query_params.copy()
        query_params['target'] = Server.objects.get(n=int(target)).host if target else None
        query_params['metric'] = Metric.objects.get(id=int(metric)).code if metric else None

        query = '''select
            target_host,
            metric_name,
            round(avg(value), 3),
            round(stddevPop(value), 3),
            round(min(value), 3),
            round(max(value), 3),
            round(median(value), 3)
            from loaddb.monitoring_verbose_data_buffer
            where job_id={job}
            and job_date=toDate({job_date})
            '''
        if target:
            query += " and target_host='{target}'"
        if metric:
            query += " and metric_name='{metric}'"
        if self.start:
            query += '''
                and time >= toDateTime('{start}')
                '''
            query_params['start'] = self.start
        if self.end:
            query += '''
                and time <= toDateTime('{end}')
                '''
            query_params['end'] = self.end
        query += '''
           group by target_host, metric_name
           order by target_host, metric_name
           '''
        return self.ch_client.select(query, query_params)
