#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os
import re
import socket
import time
import calendar
from datetime import date, timedelta, datetime

import numpy as np
import requests
from dateutil import parser as dtparser
from dateutil import tz
from requests.adapters import HTTPAdapter

import yt.wrapper as yt


# TODO: copied from supervisor
class AppControl(object):
    """
        Snapshot of per-collector control states read from YT.

        The node at YT_APP_CONTROL_PATH is read once in the constructor and is
        expected to map collector names to one of 'enabled', 'paused' or
        'disabled' (case-insensitive). Collectors missing from the node are
        treated as 'enabled'.
    """
    YT_APP_CONTROL_PATH = "//home/yabs/stat/AppControl"
    _VALID_STATES = ('enabled', 'paused', 'disabled')

    def __init__(self, ytc):
        """
            ytc -- client object providing exists(path) and get(path)
        """
        self._app_control = {}
        if not ytc.exists(self.YT_APP_CONTROL_PATH):
            return
        for name, raw_state in ytc.get(self.YT_APP_CONTROL_PATH).items():
            # A non-string value is just another kind of wrong value; it must not
            # crash the whole monitoring run.
            state = raw_state.lower() if hasattr(raw_state, 'lower') else raw_state
            if state not in self._VALID_STATES:
                logging.warning("Wrong app control value '%s' for '%s'", state, name)
                # Treat wrong value as 'paused' as the most safe and recoverable state
                state = 'paused'
            self._app_control[name] = state

    def is_collector_can_run(self, collector_name):
        """Return True only when the collector state is 'enabled'."""
        return self._app_control.get(collector_name, 'enabled') == 'enabled'

    def is_collector_enabled(self, collector_name):
        """Return True unless the collector state is 'disabled' (paused counts as enabled)."""
        return self._app_control.get(collector_name, 'enabled') != 'disabled'


class LogsMonitorExtractor(object):
    """
        Класс для получения данных о логах из
        Yt (LogProcessingStatus и logs/yabs-rt)
        Работает с данными моложе одной недели
    """

    TABLE_PATH = "//home/yabs/stat/LogProcessingStatus"
    FOLDER_PATH = "//home/yabs/log/yabs-rt"
    DICTS_PATH = "//home/yabs/dict"
    DICTS_PREPROD_PATH = "//home/yabs/dict/preprod"
    TIME_ATTR_NAME = "/@max_unix_time"
    REPLICAS_ATTR = "/@replicas"
    REPLICATED_TABLES_PATH = "//home/yabs/stat/replica"
    MAX_INTERVALS = 1000  # Максимальное число интервалов выбираемых из окна (обычно около 3х)
    INPUT_ROW_LIMIT = 10 * 1000 * 1000
    LOGFELLER_BATCH_TIME = {'1min': 60, '5min': 5 * 60, '1h': 60 * 60, '1d': 24 * 60 * 60}

    def __init__(self, cluster_url, token):
        """
            Create the YT clients and remember the query cut-off time.

            cluster_url -- proxy URL of the cluster the main client talks to
            token -- YT token handed to every client created here
        """
        self.cluster_url = cluster_url
        self.cli = yt.YtClient(config={
            "tabular_data_format": yt.JsonFormat(control_attributes_mode="row_fields"),
            "detached": False,
            "token": token,
            "proxy": {"url": cluster_url},
        })
        # Separate clients for the "markov" and "pythia" clusters; they are used
        # by get_replicated_dicts_lag to check the @replicas attribute.
        self.markov_cli = yt.YtClient("markov", token)
        self.pythia_cli = yt.YtClient("pythia", token)
        # Start of the day three days before today, as a local-time unix timestamp.
        cutoff_day = date.today() - timedelta(days=3)
        self.three_days_ago = int(time.mktime(cutoff_day.timetuple()))

    def get_logs_count(self):
        """
            Метод считает логи в папке и возвращает в
            виде пар {имя_лога : кол-во штук}
        """
        log_types = self.cli.list(self.FOLDER_PATH) if self.cli.exists(self.FOLDER_PATH) else []
        return {log_type: self.cli.get("{}/{}/@count".format(self.FOLDER_PATH, log_type))
                for log_type in log_types}

    def get_logtypes_collectors_map(self):
        log_types = self.cli.list(self.FOLDER_PATH) if self.cli.exists(self.FOLDER_PATH) else []
        result = {}
        for log_type in log_types:
            collectors = self.cli.select_rows(
                """
                    CollectorName
                from
                    [{}]
                where
                        UpdateTime > {}u
                    and
                       is_substr("{}", LogName)
                group by
                    CollectorName
                """.format(
                    self.TABLE_PATH,
                    self.three_days_ago,
                    log_type
                ),
                input_row_limit=self.INPUT_ROW_LIMIT
            )
            result[log_type] = [row["CollectorName"] for row in collectors]
        return result

    def _get_proc_time_done(self, log_type=None):
        """
            Метод возвращает время обработки последнего
            лога для каждого типа коллекторов у соответствующего
            типа лога
        """
        latest_done = self.cli.select_rows(
            """
                CollectorName,
                max(UpdateTime << 32u | if(is_null(StartTime), 0u, StartTime)) as mixtime
            from
                [{}]
            where
                    IsDone=1u
                and
                    UpdateTime > {}u
                and
                    is_substr("{}", LogName)
            group by
                CollectorName
            """.format(
                self.TABLE_PATH,
                self.three_days_ago,
                log_type if log_type else ""  # если не передан log_type то по всем типам
            ),
            input_row_limit=self.INPUT_ROW_LIMIT
        )
        result = {}
        for rec in latest_done:
            name = rec["CollectorName"]
            name = name.replace(".", "_")  # for legacy comp
            mixed_time = rec["mixtime"]
            start_time = mixed_time & ((1 << 32) - 1)  # unpack dates
            update_time = mixed_time >> 32
            result[name] = update_time - start_time  # сохраняем разницу
        return result

    def _get_proc_time_undone(self, max_output_rows=None):
        """
            Return, for every collector that has unfinished rows, the maximum
            time elapsed since such a row was last updated.

            max_output_rows -- passed to select_rows as output_row_limit
            Returns a dict {collector name with '.' replaced by '_': seconds since
            UpdateTime}. Collectors whose AppControl state is 'disabled' are
            skipped. The dict is empty when TABLE_PATH does not exist.
        """
        all_undone = self.cli.select_rows(
            """
                CollectorName,
                UpdateTime
            from
                [{}]
            where
                IsDone=0u
            """.format(self.TABLE_PATH),
            input_row_limit=self.INPUT_ROW_LIMIT,
            output_row_limit=max_output_rows
        ) if self.cli.exists(self.TABLE_PATH) else []
        app_control = AppControl(self.cli)
        result = {}
        cur_time = int(time.time())
        for record in all_undone:
            collector = record["CollectorName"]
            if not app_control.is_collector_enabled(collector):
                continue
            name = collector.replace(".", "_")  # for legacy compatibility
            age = cur_time - record["UpdateTime"]
            # A collector may have several unfinished rows; keep the largest age
            # instead of whichever row happens to come last.
            if name not in result or age > result[name]:
                result[name] = age
        return result

    def get_logtype_row_count(self):
        """
            Метод получает количество строк
            для каждого типа(sic!) логов.
        """
        log_types = self.cli.list(self.FOLDER_PATH) if self.cli.exists(self.FOLDER_PATH) else []
        result = {}
        for log_type in log_types:
            node_path = "{}/{}".format(self.FOLDER_PATH, log_type)
            chunks = self.cli.list(node_path)
            total_rows = 0
            for chunk in chunks:
                try:
                    total_rows += self.cli.row_count(node_path + "/" + chunk)
                except yt.YtError:
                    logging.exception('Failed to get row count of chunk %s', chunk)
            result[log_type] = total_rows
        return result

    def _get_logtype_collectors_in_window(self, log_type, window_start, window_end):
        """
            Метод получает записи по всем коллекторам в заданном интервале
        """
        working_window = self.cli.select_rows(
            """
                CollectorName,
                StartTime,
                UpdateTime,
                IsDone
            from
                [{}]
            where
                    is_substr("{}", LogName)
                and
                    StartTime > {}u
                and
                    UpdateTime < {}u
                order by
                    StartTime
                limit {}
            """.format(
                self.TABLE_PATH,
                log_type,
                window_start,
                window_end,
                self.MAX_INTERVALS
            ),
            input_row_limit=self.INPUT_ROW_LIMIT
        )
        return working_window

    def get_load_in_window(self, intervals, window_size):
        """
            Считает загрузку в окне. Интервалы строго не пересекающиеся
                  t_обработки
            t_l = -----------
                  t_окна
        """
        working_time = 0
        for interval in intervals:
            start_time = interval[0]
            update_time = interval[1]
            working_time += update_time - start_time
        return float(working_time) / window_size

    def get_logtypes_collectors_rough_load(self, window_size):
        """
            Функция рассчитывает значения загрузки для
            каждого логтайпа и каждого коллектора
            возвращает пару словарей -- для лог тайпов и для коллекторов
        """
        window_end = int(time.time())
        window_start = window_end - window_size
        log_types = self.cli.list(self.FOLDER_PATH) if self.cli.exists(self.FOLDER_PATH) else []
        lt_res = {}
        collectors_res = {}
        for log_type in log_types:  # по каждому логтайпу

            log_type_collectors = self._get_logtype_collectors_in_window(log_type, window_start, window_end)  # получаем все записи коллекторы
            merged_intervals = []
            collectors_dict = {}

            for row in log_type_collectors:
                start_time = row["StartTime"]
                update_time = row["UpdateTime"]
                collect_name = row["CollectorName"]

                if not row["IsDone"]:
                    logging.info("Collector {} currently working.".format(collect_name))
                    update_time = window_end

                merged_intervals.append((start_time, update_time))  # накапливаем все интервалы

                if collect_name not in collectors_dict:  # сохраняем группированные по коллекторам,
                    collectors_dict[collect_name] = []  # чтобы не делать лишних запросов
                collectors_dict[collect_name].append((start_time, update_time))

            united_intervals = self._union_intervals(merged_intervals)  # мерджим интервалы логтайпa
            lt_res[log_type] = self.get_load_in_window(united_intervals, window_size)  # получаем загрузку по логтайпу

            for collector in collectors_dict:  # получаем загрузку по каждому из коллекторов
                collectors_res[collector] = self.get_load_in_window(collectors_dict[collector], window_size)

        return lt_res, collectors_res  # возвращаем словарь по всем логтайпам и по коллекторам

    def get_logtypes_collectors_proc_time(self, max_output_rows=None):
        """
            Метод возвращает максимальное время обработки для
            каждого логтайпа и каждого коллектора
        """
        log_types = self.cli.list(self.FOLDER_PATH) if self.cli.exists(self.FOLDER_PATH) else []
        all_undone = self._get_proc_time_undone(max_output_rows=max_output_rows)
        lt_res = {}
        collectors_res = {}
        for log_type in log_types:  # по всем логтайпам
            log_type_collectors = self._get_proc_time_done(log_type=log_type)  # получаем последние разницы
            if not log_type_collectors:
                continue

            for collector in log_type_collectors:
                if collector in all_undone:  # если коллектор сейчас работает берем максимум из разниц
                    log_type_collectors[collector] = max(log_type_collectors[collector], all_undone[collector])

            lt_res[log_type] = max(log_type_collectors.values())  # сохраняем для логтайпа
            collectors_res.update(log_type_collectors)  # сохраняем для коллекторов

        return lt_res, collectors_res

    def _union_intervals(self, intervals):
        """
            Объединяет пересекающиеся интервалы в
            один интервал
            intervals -- массив пар (начало, конец)
        """
        result = []
        if intervals:
            cur_start = intervals[0][0]
            cur_end = intervals[0][1]
            for interv in intervals[1:]:
                if interv[0] > cur_end:
                    result.append((cur_start, cur_end))
                    cur_start = interv[0]
                    cur_end = interv[1]
                elif cur_end < interv[1]:
                    cur_end = interv[1]
            result.append((cur_start, cur_end))
        return result

    def get_logfeller_times(self, path):
        result = {}
        for log_type, logs in self.cli.get('{}&'.format(path), attributes=['target_path']).iteritems():
            log_data_times = []
            for log in logs.itervalues():
                if 'target_path' in log.attributes:
                    parts = log.attributes['target_path'].split('/')
                    if len(parts) >= 2 and parts[-2] in self.LOGFELLER_BATCH_TIME:
                        time_start = dtparser.parse(parts[-1]).replace(tzinfo=tz.tzlocal())
                        time_end = time_start + timedelta(0, self.LOGFELLER_BATCH_TIME[parts[-2]])
                        log_data_times.append((time_start, time_end))
            result[log_type] = log_data_times
        return result

    def get_logtypes_reader_lag(self):
        """Для каждого типа логов возвращает возраст данных в самом свежем батче"""
        cur_time = int(time.time())
        result = {}
        if not self.cli.exists(self.FOLDER_PATH):
            return result
        for log_type, logs in self.cli.get(self.FOLDER_PATH, attributes=['max_unix_time']).iteritems():
            log_data_times = [
                log.attributes['max_unix_time']
                for log in logs.itervalues()
                if 'max_unix_time' in log.attributes
            ]
            if log_data_times:
                result[log_type] = cur_time - max(log_data_times)

        for log_type, log_times in self.get_logfeller_times(self.FOLDER_PATH).iteritems():
            if log_type in result:
                continue
            log_data_times = [calendar.timegm(end_time.utctimetuple()) for _, end_time in log_times]
            result[log_type] = cur_time - max(log_data_times) if log_data_times else 0  # no tables => log delay is zero
        return result

    def get_logtypes_oldest_log_age(self):
        """For every log type return the age, in seconds, of its oldest batch.

        Start times from get_logfeller_times are used when available; otherwise
        the creation_time attribute of the batches is used. For symlinks that
        is the creation time of the symlink, not of the log it points to.
        A log type with no batches gets an age of zero. Returns an empty dict
        when FOLDER_PATH does not exist.
        """
        cur_time = datetime.now(tz.tzutc())
        result = {}
        if not self.cli.exists(self.FOLDER_PATH):
            return result
        # items()/values() instead of the Python-2-only iteritems()/itervalues().
        for log_type, log_times in self.get_logfeller_times(self.FOLDER_PATH).items():
            if log_times:
                log_data_times = [start_time for start_time, _ in log_times]
                result[log_type] = int((cur_time - min(log_data_times)).total_seconds())

        for log_type, logs in self.cli.get(self.FOLDER_PATH, attributes=['creation_time']).items():
            if log_type in result:
                continue
            # for symlinks it returns symlink creation time, not target creation time
            log_creation_times = [log.attributes['creation_time'] for log in logs.values()]
            if not log_creation_times:
                result[log_type] = 0  # no tables => log age is zero
            else:
                oldest_table_time = dtparser.parse(min(log_creation_times))
                log_age = cur_time - oldest_table_time
                result[log_type] = int(log_age.total_seconds())
        return result

    def _check_dict(self, dpath):
        try:
            return self.cli.exists(dpath + self.TIME_ATTR_NAME)
        except yt.errors.YtHttpResponseError as e:
            if e.is_access_denied():
                logging.warn(e.error['message'])
                return False
            raise e

    def get_replicated_dicts_lag(self, dicts_path):
        """
        Возвращает словарь Справочник -> max_unix_time
        Для таблиц на приемниках (динамические таблицы-справочники)
        """
        all_dicts = self.cli.list(dicts_path) if self.cli.exists(dicts_path) else []
        logging.info("Found {} dicts".format(len(all_dicts)))

        if dicts_path == self.DICTS_PATH:
            master_cli = self.markov_cli
        elif dicts_path == self.DICTS_PREPROD_PATH:
            master_cli = self.pythia_cli
        else:
            raise ValueError("Wrong dict path {}".format(dicts_path))

        result = {}
        for dpath in all_dicts:
            full_path = os.path.join(dicts_path, dpath)
            logging.info("Checking dict {}".format(full_path))
            if self._check_dict(full_path) and master_cli.exists(self.DICTS_PATH + '/' + dpath + self.REPLICAS_ATTR):
                logging.info("{} is dict with attr {}".format(full_path, self.TIME_ATTR_NAME))
                result[dpath] = int(time.time()) - self.cli.get_attribute(full_path, "max_unix_time")
            else:
                logging.info("Dict {} doesn't pass check".format(full_path))
        return result

    def get_replicated_tables_lag(self):
        """
        Возвращает словарь Справочник -> replication_lag_time (в секундах)
        для таблиц на реплицирующем кластере (таблицы-репликаторы)
        """
        result = {}
        if not self.cli.exists(self.REPLICATED_TABLES_PATH):
            return result
        replicated_tables = self.cli.list(self.REPLICATED_TABLES_PATH)
        logging.info("Found candidates {} tables on cluster {}".format(len(replicated_tables), self.cluster_url))
        for dpath in replicated_tables:
            full_path = os.path.join(self.REPLICATED_TABLES_PATH, dpath)
            logging.info("Checking table {}".format(full_path))
            if self.cli.exists(full_path + self.REPLICAS_ATTR):
                logging.info("Table {} replicated".format(full_path))
                replicas = self.cli.get(full_path + self.REPLICAS_ATTR)
                for _, replica_info in replicas.iteritems():
                    if "mode" not in replica_info or replica_info["mode"] == "async":
                        result[(dpath, replica_info["cluster_name"])] = replica_info["replication_lag_time"] / 1000
            else:
                logging.info("Table {} is not replicated".format(full_path))
        return result


class GraphiteWorker(object):
    """
        Buffers metric points, sends them to Graphite and reads series back.

        Points are collected with append() (the largest value per metric path
        wins) and sent to every host of GRAPHITE_HOSTS by push().
        Hosts and the render URL are hard-coded for now.
    """
    GRAPHITE_HOSTS = [
        "mega-graphite-man.search.yandex.net:2024",
        "mega-graphite-sas.search.yandex.net:2024"
    ]

    # Period prefixes used as the first component of a metric path.
    ONE_SEC = "one_sec"
    FIVE_SEC = "five_sec"
    TEN_SEC = "ten_sec"
    ONE_MIN = "one_min"
    FIVE_MIN = "five_min"
    TEN_MIN = "ten_min"
    ONE_HOUR = "one_hour"
    ONE_DAY = "one_day"

    GRAPHITE_URL = "http://bs-mg.yandex-team.ru/render/?from=-{}s&until=-0s&target={}&format=json"

    def __init__(self):
        # metric path -> (largest point seen so far, line to send for it)
        self.seen_metrics = {}

    def fix_collector_name(self, name):
        """
            Bring a collector name to a form usable in a Graphite path.

            Trailing non-alphanumeric characters are dropped, then every run of
            characters other than ASCII letters, digits, '-' and '.' is replaced
            with a single '_'. A name without any alphanumeric character gives
            an empty string.
        """
        # `name and` stops the loop once everything has been stripped; without
        # it a name such as "___" raised IndexError.
        while name and not name[-1].isalnum():
            name = name[:-1]
        return re.sub('[^0-9a-zA-Z-.]+', '_', name)

    def append(self, period, machine_name, prefix_name, collector_name, metric_name, point, timestamp=None):
        """
            Remember a point to be sent by the next push().

            The metric path is period.machine.prefix[.collector].metric, where
            dots in machine_name are replaced with '_' and the collector part is
            omitted when collector_name is empty (or sanitizes to nothing).
            When timestamp is None the current time is used.
            If the same path is appended several times before a push, the point
            with the largest value is kept (the latest one on ties).
        """
        parts = [period, machine_name.replace(".", "_"), prefix_name]
        fixed_collector = self.fix_collector_name(collector_name) if collector_name else ""
        if fixed_collector:
            parts.append(fixed_collector)
        parts.append(metric_name)
        path = '.'.join(["{}"] * len(parts)).format(*parts)

        tm = int(time.time()) if timestamp is None else timestamp
        data = " ".join([path, str(point), str(tm)])

        known = self.seen_metrics.get(path)
        if known is None or known[0] <= point:
            self.seen_metrics[path] = (point, data)

    def push(self):
        """
            Print the buffered lines, send them to every Graphite host and
            clear the buffer.

            A failure to reach one host is logged and does not prevent sending
            to the others; the buffer is cleared in any case.
        """
        rows = "\n".join(sorted(x[1] for x in self.seen_metrics.values())) + "\n"
        print(rows)
        # Sockets take bytes; a Python 2 `str` is already bytes and is sent as is.
        payload = rows if isinstance(rows, bytes) else rows.encode("utf-8")
        for server in self.GRAPHITE_HOSTS:
            (host, port) = server.split(":")
            try:
                sock = socket.create_connection((host, port), 5)  # 5 sec timeout
                try:
                    sock.sendall(payload)
                finally:
                    # Close the socket even when sendall fails.
                    sock.close()
            except socket.error as msg:
                logging.warning("Send metrics to {} failed with error: {}".format(server, msg))
        self.seen_metrics = {}

    def request(self, expression, time_back=300):
        """
            Fetch the render result of a target expression for the last
            time_back seconds.

            Returns the decoded JSON body, or an empty dict when the response
            status is not OK.
        """
        url = self.GRAPHITE_URL.format(time_back, expression)
        # The retrying adapter has to be mounted on a session: the second
        # positional argument of requests.get() is `params`, so passing the
        # adapter there never configured any retries.
        session = requests.Session()
        try:
            adapter = HTTPAdapter(max_retries=5)
            session.mount("http://", adapter)
            session.mount("https://", adapter)
            result = session.get(url)
            if result.ok:
                return result.json()
            return {}
        finally:
            session.close()

    def get_graphite_points_with_interp(self, expression, target_pos, time_back=300):
        """
            Fetch series from Graphite and reduce each one to a single integer.

            expression -- Graphite target expression
            target_pos -- index of the dot-separated component of the returned
                          target name that is used as the result key
            time_back -- how many seconds back to request
            Returns a dict {key: value}. The value is the series interpolated at
            5 seconds past its last timestamp, or None when the series has no
            non-null points. When several series map to the same key the
            largest value is kept.
        """
        raw_data = self.request(expression, time_back)
        result = {}
        for data_dict in raw_data:
            target = data_dict["target"].split('.')[target_pos]
            raw_points = data_dict["datapoints"]
            datapoints = [(p[1], p[0]) for p in raw_points if p[0] is not None]
            if not datapoints:
                # Nothing to interpolate; do not overwrite a value another
                # series may already have provided for this key.
                result.setdefault(target, None)
                continue
            last_ts = raw_points[-1][1]
            xs, ys = zip(*datapoints)
            value = int(np.interp(last_ts + 5, xs, ys))
            if target in result:
                print("Collision for target {} while fetching from graphite".format(target))
                if result[target] is not None:
                    value = max(result[target], value)
            result[target] = value
        return result
