from datetime import datetime, timedelta
import logging
import os
import time

from sandbox.sandboxsdk import environments
from sandbox import sdk2
from sandbox.projects.cloud.analytics.common.utils import last_rounded_ts, grouper_it, get_last_table, ISO_FORMAT
from sandbox.projects.cloud.analytics.common.analytics_task import AnalyticsTask


class CloudSolomonToYT(AnalyticsTask):
    """Sandbox task that imports Solomon resource metrics into YT for cloud analytics.

    For every aggregate in ``AGGREGATES`` the task writes one table per
    ``INTERVAL`` under ``YT_PREFIX/<aggregate>/``, named after the interval's
    end timestamp (formatted with ``ISO_FORMAT``). Import resumes from the
    table name returned by ``get_last_table``, or from ``DEFAULT_FROM_TS``
    when that helper returns nothing.
    """
    BATCH_SIZE = 1  # node names joined with '|' into one Solomon request
    SOLOMON_PROJECT = 'yandexcloud'
    INTERVAL = 60 * 60  # 1 hour, in seconds
    AGGREGATES = (
        'min',
        'max',
        'last',
        'avg',
        'sum',
    )
    YT_PREFIX = '//home/cloud_analytics/import/resources/1h'
    SOLOMON_CLUSTER = 'cloud_prod_scheduler'
    SOLOMON_SERVICE = 'resources'
    SOLOMON_METRICS = '*cores*|memory_*|nvme_*'
    # Start of the current UTC day, evaluated once when the class body runs.
    # A single utcnow() call is used so year/month/day cannot come from
    # different days when the class is defined right around midnight.
    DEFAULT_FROM_TS = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)

    class Requirements(AnalyticsTask.Requirements):
        # TODO(syndicut): Use prebuilt wheels here
        environments = (
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet'),
            environments.PipEnvironment('requests'),
            environments.PipEnvironment('futures'),
            environments.PipEnvironment('requests_futures', version='0.9.9'),  # https://st.yandex-team.ru/CLOUDPS-502
        )

    def on_execute(self):
        """Import every pending ``INTERVAL`` window for each aggregate from Solomon into YT.

        Third-party modules are imported here rather than at module level,
        presumably because they are only available once the pip environments
        declared in ``Requirements`` are installed.
        """
        import yt.wrapper as yt
        from sandbox.projects.cloud.analytics.common.solomon.client import SolomonAPI, SolomonDataAPI
        from concurrent.futures import as_completed

        def parse_solomon(project, cluster, service, metric, node_names, aggregate, from_ts, to_ts, solomon_data_api):
            """Yield one row per Solomon sensor for the window ``from_ts``..``to_ts``.

            One request is issued per batch of ``BATCH_SIZE`` node names, each
            asking for a single point downsampled with ``aggregate``. Rows are
            yielded in request-completion order. Each row is a copy of the
            sensor's labels plus ``<aggregate>`` (the point's value) and
            ``start``/``end`` epoch seconds. Sensors without points are skipped.
            """
            begin = from_ts.strftime(SolomonDataAPI.TIME_FORMAT)
            end = to_ts.strftime(SolomonDataAPI.TIME_FORMAT)
            # NOTE(review): time.mktime interprets the naive datetimes as host-local
            # time, while from_ts/to_ts appear to be UTC (DEFAULT_FROM_TS uses
            # utcnow). On a non-UTC host these epochs are shifted by the UTC offset;
            # calendar.timegm would be UTC-correct but changes already-written data.
            start_epoch = time.mktime(from_ts.timetuple())
            end_epoch = time.mktime(to_ts.timetuple())

            futures = []
            for batch in grouper_it(self.BATCH_SIZE, node_names):
                extra_params = {
                    'l.zone_id': '*',
                    'l.metric': metric,
                    'l.node_name': '|'.join(batch),
                    'l.host_group': 'all|-',
                    'b': begin,
                    'e': end,
                    'points': 1,
                    'downsamplingAggr': aggregate,
                }
                futures.append(solomon_data_api.get(project, cluster, service, extra_params))

            for future in as_completed(futures):
                for sensor in future.result().data['sensors']:
                    values = sensor['values']
                    if not values:
                        continue
                    # Copy so the response payload is not mutated by the extra columns.
                    row = dict(sensor['labels'])
                    row[aggregate] = values[0]['value']
                    row['start'] = start_epoch
                    row['end'] = end_epoch
                    yield row

        yt.config['token'] = sdk2.Vault.data(self.owner, "robot-clanalytics-yt-yt-token")
        yt.config['proxy']['url'] = 'hahn'
        # yt.config['proxy']['request_retry_count'] = 12

        solomon_api = SolomonAPI(sdk2.Vault.data(self.owner, "robot-clanalytics-yt-solomon-token"))
        solomon_data_api = SolomonDataAPI('robot-clanalytics-yt', use_futures=True)

        label_values = solomon_api.labels(self.SOLOMON_PROJECT, {'names': 'node_name', 'limit': 10000})['labels'][0]['values']
        # 'all_nodes' is dropped — presumably an aggregate pseudo-node; TODO confirm.
        node_names = [name for name in label_values if name != 'all_nodes']
        logging.info('Got %s node names', len(node_names))

        for aggregate in self.AGGREGATES:
            table_path = '/'.join([self.YT_PREFIX, aggregate])
            last_table = get_last_table(yt, table_path)
            # Tables are named after their window's end, so the last table's name
            # is where the next window starts.
            from_ts = datetime.strptime(last_table, ISO_FORMAT) if last_table else self.DEFAULT_FROM_TS
            to_ts = from_ts + timedelta(seconds=self.INTERVAL)
            # Presumably last_rounded_ts is "now" rounded down to INTERVAL, so only
            # fully elapsed windows are imported — confirm against the helper.
            while to_ts <= last_rounded_ts(self.INTERVAL):
                table = '/'.join((table_path, to_ts.strftime(ISO_FORMAT)))
                logging.info('Importing table %s from %s to %s', table, from_ts, to_ts)
                with yt.Transaction():
                    parent = os.path.dirname(table)
                    if not yt.exists(parent):
                        yt.create('map_node', parent, recursive=True)
                    # The window starts one second after from_ts, presumably so the
                    # previous window's end point is not queried twice.
                    rows = parse_solomon(self.SOLOMON_PROJECT,
                                         self.SOLOMON_CLUSTER,
                                         self.SOLOMON_SERVICE,
                                         self.SOLOMON_METRICS,
                                         node_names,
                                         aggregate,
                                         from_ts + timedelta(seconds=1),
                                         to_ts,
                                         solomon_data_api)
                    yt.write_table(yt.TablePath(table, append=True), rows)
                from_ts, to_ts = to_ts, to_ts + timedelta(seconds=self.INTERVAL)
