import logging
import re
from collections import OrderedDict, defaultdict
from copy import deepcopy, copy
from datetime import datetime, timedelta

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types.task import Status as TaskStatus
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.projects.yabs.qa.solomon.mixin import SolomonTaskMixinParameters, init_push_client
from sandbox.projects.common.binary_task import LastBinaryTaskRelease, LastBinaryReleaseParameters
from sandbox.projects.yabs.monitoring.tasks.CollectYtUsageStatistics.lib.operations import (
    TIME_FORMAT,
    OperationState,
    schedule_operation,
    init_tmp_table,
    iter_operations_info,
)
from sandbox.projects.yabs.monitoring.tasks.CollectYtUsageStatistics.lib.aggregation import (
    AggregationKey,
    get_bucket,
)
from sandbox.projects.yabs.monitoring.tasks.CollectYtUsageStatistics.lib.tables import get_list_of_tables
from sandbox.projects.yabs.monitoring.tasks.CollectYtUsageStatistics.lib.statistics import iter_metrics
from sandbox.projects.yabs.monitoring.tasks.CollectYtUsageStatistics.lib.main import get_aggregated_operations
from sandbox.projects.yabs.monitoring.tasks.CollectYtUsageStatistics.lib.report import create_html_report


# YT clusters offered by the yt_cluster / yt_clusters parameters of
# CollectYtUsageStatistics; all of them are selected by default.
YT_CLUSTERS = ["hahn", "arnold", "freud"]


def parse_human_readable(td_str):
    """Convert an ``HH:MM:SS`` duration string to a number of seconds.

    Hours may have any number of digits and may exceed 23 (``"24:00:00"`` is
    accepted); minutes and seconds are not range-checked. Surrounding
    whitespace is ignored, but any other leading or trailing text is rejected.

    :param td_str: duration string, e.g. ``"01:30:00"``.
    :returns: total duration in seconds as ``int``.
    :raises TypeError: if ``td_str`` does not match ``HH:MM:SS`` (``TypeError``
        rather than ``ValueError`` is kept for backward compatibility).
    """
    # ``re.match`` only anchors at the start of the string; ``\Z`` anchors at
    # the very end, so inputs such as "01:00:00junk" are no longer accepted.
    pattern = re.compile(r'\s*(?P<hours>\d+):(?P<minutes>\d+):(?P<seconds>\d+)\s*\Z')
    m = pattern.match(td_str)
    if m is None:
        raise TypeError("String '{}' has unsupported format".format(td_str))
    hours, minutes, seconds = (int(m.group(name)) for name in ('hours', 'minutes', 'seconds'))
    return hours * 60 * 60 + minutes * 60 + seconds


class TimeDelta(sdk2.parameters.String):
    """String task parameter whose value is a duration in seconds.

    Accepts either a plain number of seconds (``"300"`` or ``300``) or an
    ``HH:MM:SS`` string (``"00:05:00"``).
    """

    @classmethod
    def cast(cls, value):
        """Convert ``value`` to a number of seconds.

        :param value: a string (integer seconds or ``HH:MM:SS``), an ``int``
            or a ``float``.
        :returns: ``int`` for string input; numeric input is returned unchanged.
        :raises TypeError: for unsupported value types or malformed strings.
        """
        if isinstance(value, basestring):
            try:
                return int(value)
            except ValueError:
                # Not a plain integer number of seconds -- try HH:MM:SS.
                return parse_human_readable(value)
        if isinstance(value, (int, float)):
            return value
        raise TypeError("Unable to parse {} {} as timedelta".format(type(value), value))


class CollectYtUsageStatistics(LastBinaryTaskRelease, sdk2.Task):
    """Collect usage statistics of YT operations and publish them.

    Operations are selected from the tables returned by ``get_list_of_tables``
    (the ``//logs/yt-scheduler-log`` tables, read through the hahn client) by
    user, cluster and operation state. They are assigned to fixed-size time
    buckets by finish time and aggregated by the keys chosen in ``aggr_by``.
    The result is optionally rendered as an HTML report into the task info,
    and/or pushed to Solomon with one timestamp per bucket.
    """

    class Parameters(LastBinaryReleaseParameters):
        solomon_parameters = SolomonTaskMixinParameters()

        with sdk2.parameters.Group("Report parameters") as report_parameters:
            create_report = sdk2.parameters.Bool("Write report to task's info", default=True)

        with sdk2.parameters.Group("Statistics") as statistics_parameters:
            timespan_type = sdk2.parameters.String("Timespan type", default="from_now", choices=[(c, c) for c in ("from_now", "explicit", "stream")])
            with timespan_type.value["explicit"]:
                timespan = sdk2.parameters.String("Time interval to gather statistics from", description="Format {f}/{f}".format(f=TIME_FORMAT))
            with timespan_type.value["from_now"]:
                interval = TimeDelta("Time interval to gather statistics from", default="01:00:00", description="Format HH:MM:SS")
            with timespan_type.value["stream"]:
                timespan_limit = TimeDelta("Limit timespan size", default="24:00:00", description="Format HH:MM:SS")
            aggregation_interval = TimeDelta("Interval for operations aggregation", default="00:05:00", description="Format HH:MM:SS")

        with sdk2.parameters.Group("Filters") as filters_parameters:
            users = sdk2.parameters.List("Users to filter operations", default=["robot-yabs-cs-sb", "robot-yabs-cs-sbjail"])
            yt_cluster = sdk2.parameters.String("Cluster to filter operations", default="hahn", choices=[(cluster, cluster) for cluster in YT_CLUSTERS])
            with sdk2.parameters.CheckGroup("YT Clusters to filter operations", default=YT_CLUSTERS) as yt_clusters:
                for cluster_name in YT_CLUSTERS:
                    yt_clusters.values[cluster_name] = cluster_name
            with sdk2.parameters.CheckGroup("Operation states", default=[OperationState.COMPLETED, OperationState.FAILED]) as operation_states:
                choices = [
                    OperationState.COMPLETED,
                    OperationState.FAILED,
                    OperationState.ABORTED,
                ]
                for choice in choices:
                    operation_states.values[choice] = choice
            with sdk2.parameters.CheckGroup("Aggregate by", default=[AggregationKey.Sandbox.TASK_TYPE, AggregationKey.Yt.OPERATION_STATE]) as aggr_by:
                choices = [
                    AggregationKey.Sandbox.BIN_DB_LIST,
                    AggregationKey.Sandbox.TASK_TYPE,
                    AggregationKey.Sandbox.RUN_TYPE,
                    AggregationKey.Yt.OPERATION_STATE,
                    AggregationKey.Yt.CLUSTER,
                    AggregationKey.Yt.POOL,
                    AggregationKey.Yt.Annotations.SERVANT_NAME,
                    AggregationKey.Yt.Annotations.CONTENT_SYSTEM_KEY,
                    AggregationKey.Yt.Annotations.TAGS,
                ]
                for choice in choices:
                    aggr_by.values[choice] = choice

        with sdk2.parameters.Group("Misc") as misc_parameters:
            yt_token_vault_name = sdk2.parameters.String('Vault name for YT token', required=True, default='yabs-cs-sb-yt-token')

        with sdk2.parameters.Output:
            out_from_time = sdk2.parameters.Integer("Result timespan begin")
            out_to_time = sdk2.parameters.Integer("Result timespan end")

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 1024
        environments = (
            PipEnvironment('yandex-yt', use_wheel=True),
        )

        class Caches(sdk2.Requirements.Caches):
            pass

    @staticmethod
    def _to_timestamp(dt):
        """Return seconds since the Unix epoch (float) for a naive UTC datetime."""
        return (dt - datetime.utcfromtimestamp(0)).total_seconds()

    def get_timespan(self):
        """Return the ``(from_time, to_time)`` interval to collect statistics for.

        Both values are naive UTC datetimes. How the interval is chosen depends
        on ``Parameters.timespan_type``:

        * ``explicit``: parsed from ``Parameters.timespan`` (``<from>/<to>``,
          each part in ``TIME_FORMAT``);
        * ``from_now``: the last ``Parameters.interval`` seconds up to now;
        * ``stream``: from the ``out_to_time`` of the latest successful task of
          the same scheduler up to now, capped at ``Parameters.timespan_limit``
          seconds. This mode is only valid for scheduler-started tasks.

        :raises TaskFailure: on an unknown timespan type, a malformed explicit
            timespan, a stream run outside a scheduler or without a usable
            previous run, or an empty interval.
        """
        timespan_type = self.Parameters.timespan_type
        if timespan_type == "explicit":
            raw_timespan = self.Parameters.timespan or ""
            parts = [part.strip() for part in raw_timespan.split("/")]
            if len(parts) != 2:
                raise TaskFailure("Timespan '{}' must have format {f}/{f}".format(raw_timespan, f=TIME_FORMAT))
            from_time = datetime.strptime(parts[0], TIME_FORMAT)
            to_time = datetime.strptime(parts[1], TIME_FORMAT)
        elif timespan_type == "from_now":
            to_time = datetime.utcnow()
            from_time = to_time - timedelta(seconds=self.Parameters.interval)
        elif timespan_type == "stream":
            if self.scheduler is None:
                raise TaskFailure("Timespan type {} should be used only from scheduler".format(timespan_type))
            to_time = datetime.utcnow()
            # NOTE(review): a scheduler with no successful run yet can never
            # start in stream mode; confirm how such a stream is bootstrapped.
            last_succeeded_task = sdk2.Task.find(scheduler=self.scheduler, status=TaskStatus.SUCCESS, order="-id", limit=1).first()
            if not last_succeeded_task:
                raise TaskFailure("No succeeded tasks found for scheduler {}".format(self.scheduler))
            last_to_time = last_succeeded_task.Parameters.out_to_time
            if last_to_time is None:
                raise TaskFailure("Task {} has no out_to_time output".format(last_succeeded_task.id))
            from_time = datetime.utcfromtimestamp(last_to_time)
            to_limit = from_time + timedelta(seconds=self.Parameters.timespan_limit)
            if to_time > to_limit:
                logging.warning("Too big timespan: %s, will use %s", to_time - from_time, to_limit)
                to_time = to_limit
        else:
            raise TaskFailure("Unexpected timespan type {}".format(timespan_type))
        if from_time >= to_time:
            raise TaskFailure("Empty timespan: from {} to {}".format(from_time, to_time))
        return from_time, to_time

    def _aggregation_labels(self, op_key):
        """Map each configured ``aggr_by`` key name to its value in ``op_key``."""
        return OrderedDict([
            (aggr_key_name, op_key[i])
            for i, aggr_key_name in enumerate(self.Parameters.aggr_by)
        ])

    def on_execute(self):
        """Collect, aggregate and publish YT operation statistics.

        Side effects:
        * sets the ``out_from_time`` and ``out_to_time`` output parameters;
        * writes messages, and optionally an HTML report, to the task info;
        * optionally pushes sensors to Solomon.
        """
        from yt.wrapper import YtClient

        # range() below needs a positive integer step. Validate it before
        # anything is scheduled on YT.
        aggregation_interval = int(self.Parameters.aggregation_interval)
        if aggregation_interval <= 0:
            raise TaskFailure("Aggregation interval must be positive, got {}".format(self.Parameters.aggregation_interval))

        yt_token = sdk2.Vault.data(self.Parameters.yt_token_vault_name)
        ytc_hahn = YtClient(proxy="hahn", token=yt_token)

        from_time, to_time = self.get_timespan()
        from_ts = self._to_timestamp(from_time)
        to_ts = self._to_timestamp(to_time)
        self.Parameters.out_from_time = int(from_ts)

        logging.info("Filtering operations from %s to %s", from_time, to_time)
        self.set_info("Filtering operations from {} to {}".format(from_time, to_time))

        tables = get_list_of_tables(from_time, to_time, ytc_hahn)
        if not tables:
            raise TaskFailure("No tables found in //logs/yt-scheduler-log for specified time interval")

        filter_operations = {
            "users": self.Parameters.users,
            "from_time": from_time,
            "to_time": to_time,
            "cluster_names": self.Parameters.yt_clusters,
            "log_tables": tables,
            "operation_states": self.Parameters.operation_states,
        }
        logging.info("Filter: %s", filter_operations)

        tmp_table = init_tmp_table(ytc_hahn, "//home/yabs-cs-sandbox/monitoring/op_stats/tmp_dir", self.id)
        op = schedule_operation(ytc_hahn, tmp_table, **filter_operations)
        self.set_info("Operation <a href=\"{}\" target=\"_blank\">{}</a> scheduled".format(op.url, op.id), do_escape=False)

        # Bucket start timestamps, aggregation_interval seconds apart.
        buckets = list(range(int(from_ts), int(to_ts), aggregation_interval))
        logging.debug("Buckets are: %s", buckets)

        operations_by_finish_time = defaultdict(list)
        all_operations = []
        last_bucket = None
        first_operation_finish = None
        last_operation_finish = None

        for operation in iter_operations_info(ytc_hahn, op, tmp_table):
            finish_time = datetime.strptime(operation["finish_time"], TIME_FORMAT)
            bucket = get_bucket(buckets, self._to_timestamp(finish_time))
            if bucket is None:
                logging.warning("Operation %s has finish_time %s greater than the rightmost boundary, therefore will be skipped", operation["operation_id"], finish_time)
                continue
            logging.debug("Found bucket %s for %s", datetime.utcfromtimestamp(bucket), finish_time)
            if self.Parameters.push_to_solomon:
                operations_by_finish_time[bucket].append(operation)
            if self.Parameters.create_report:
                all_operations.append(operation)
            # Track the latest bucket separately. operations_by_finish_time is
            # only filled when pushing to Solomon, but out_to_time needs this
            # value in every mode.
            last_bucket = bucket if last_bucket is None else max(last_bucket, bucket)
            first_operation_finish = finish_time if first_operation_finish is None else min(first_operation_finish, finish_time)
            last_operation_finish = finish_time if last_operation_finish is None else max(last_operation_finish, finish_time)

        logging.info("Got operations between %s and %s", first_operation_finish, last_operation_finish)

        # A following "stream" run starts at out_to_time: the start of the
        # latest bucket that received operations, or from_time if none did.
        # That bucket is therefore collected again by the next run.
        out_to_candidates = [from_ts] if last_bucket is None else [from_ts, last_bucket]
        self.Parameters.out_to_time = int(max(out_to_candidates))

        if self.Parameters.create_report:
            self._create_report(all_operations, first_operation_finish, last_operation_finish)

        if self.Parameters.push_to_solomon:
            self._push_to_solomon(operations_by_finish_time)

    def _create_report(self, operations, first_operation_finish, last_operation_finish):
        """Aggregate ``operations`` and write an HTML table to the task info.

        The table has one row per aggregation key. Its columns are the
        ``aggr_by`` keys followed by every sensor name produced by
        ``iter_metrics``.
        """
        aggregated_operations = get_aggregated_operations(operations, self.Parameters.aggr_by, sandbox_client=self.server)
        header = list(self.Parameters.aggr_by)
        report_rows = []
        for op_key, ops_info_list in aggregated_operations.items():
            logging.debug("Operation key=%s", op_key)
            # A fresh dict per row, so it can be filled in place.
            report_row = self._aggregation_labels(op_key)
            for metric_labels, value in iter_metrics(ops_info_list):
                sensor = metric_labels["sensor"]
                report_row[sensor] = value
                if sensor not in header:
                    header.append(sensor)
            report_rows.append(report_row)

        self.set_info(create_html_report(report_rows, header, from_time=first_operation_finish, to_time=last_operation_finish), do_escape=False)

    def _push_to_solomon(self, operations_by_finish_time):
        """Aggregate each bucket's operations and push them to Solomon.

        Each sensor is labelled with the ``aggr_by`` values plus the metric's
        own labels, and is timestamped with its bucket start.
        """
        sensors = []
        solomon_push_client = init_push_client(self)
        for timestamp, operations in operations_by_finish_time.items():
            aggregated_operations = get_aggregated_operations(operations, self.Parameters.aggr_by, sandbox_client=self.server)
            for op_key, ops_info_list in aggregated_operations.items():
                logging.debug("Operation key=%s", op_key)
                labels = self._aggregation_labels(op_key)
                for metric_labels, value in iter_metrics(ops_info_list):
                    # Copy per sensor, so the shared labels are not mutated.
                    sensor_labels = deepcopy(labels)
                    sensor_labels.update(metric_labels)
                    sensors.append({
                        "labels": sensor_labels,
                        "value": value,
                        "ts": timestamp,
                    })

        solomon_push_client.add(sensors)
        solomon_push_client.push_collected()
