import calendar
from datetime import datetime, timedelta
from itertools import chain
import logging
import os
import time

from sandbox.sandboxsdk import environments
from sandbox import sdk2
from sandbox.projects.cloud.analytics.common.utils import last_rounded_ts, grouper_it, get_last_table, ISO_FORMAT
from sandbox.projects.cloud.analytics.common.analytics_task import AnalyticsTask

class CloudSolomonComputeToYT(AnalyticsTask):
    """Task to import data about compute from Solomon to YT for cloud analytics.

    For every ``INTERVAL`` of the previous UTC day the task requests the
    ``AGGREGATES`` downsampled values of compute (CPU / memory) and disk
    metrics for every instance found in the compute billing log, and upserts
    them into the dynamic table ``YT_PREFIX``.  Interval boundaries are stored
    negated in the ``start``/``end`` key columns, so the newest interval is the
    smallest key.
    """
    SOLOMON_PROJECT = 'yandexcloud'
    INTERVAL = 60 * 60  # 1 hour, in seconds
    AGGREGATES = (
        'min',
        'max',
        'avg',
    )
    YT_PREFIX = '//home/cloud_analytics/import/solomon/compute'
    SOLOMON_CLUSTER = 'cloud_prod_compute'
    SOLOMON_SERVICE = 'compute'
    SOLOMON_COMPUTE_METRIC = 'cpu_usage|memory_usage'
    SOLOMON_DISK_METRIC = 'disk_*'
    # NOTE: evaluated once, when the module is imported, not per task run.
    # A single utcnow() call avoids mixing date parts across midnight.
    DEFAULT_TO_TS = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)  # Start of today (UTC)
    DEFAULT_FROM_TS = DEFAULT_TO_TS - timedelta(days=1)  # Start of yesterday (UTC)
    QUERY = """
USE hahn;

$from_date = DateTime::FromSeconds(%FROM_TS%);
$to_date = DateTime::FromSeconds(%TO_TS%);

$callable = ($name) -> {
    $date = DateTime::MakeDate(DateTime::Parse("%Y-%m-%d")($name));
    return $date >= $from_date - DateTime::IntervalFromDays(1) AND $date <= $to_date + DateTime::IntervalFromDays(1);
};

INSERT INTO `%TMP_TABLE%` WITH TRUNCATE
SELECT DISTINCT(resource_id) FROM FILTER("//home/logfeller/logs/yc-billing-compute-instance/1d", $callable) WHERE _logfeller_timestamp >= %FROM_TS% AND _logfeller_timestamp < %TO_TS%;
"""

    class Parameters(AnalyticsTask.Parameters):
        batch_size = sdk2.parameters.Integer(
            'Size of instances batches for request compute metrics',
            default=20,
            required=True
        )

        page_size = sdk2.parameters.Integer(
            'Size of pages for request disk metrics',
            default=500,
            required=True
        )

        futures_size = sdk2.parameters.Integer(
            'Count of futures batch with solomon data request',
            default=10000,
            required=True
        )

        rows_size = sdk2.parameters.Integer(
            'Rows to push into yt database',
            default=75000,
            required=True
        )

    class Requirements(AnalyticsTask.Requirements):
        # TODO(syndicut): Use prebuilt wheels here
        environments = (
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet'),
            environments.PipEnvironment('requests'),
            environments.PipEnvironment('futures'),
            environments.PipEnvironment('requests_futures', version='0.9.9'),  # https://st.yandex-team.ru/CLOUDPS-502
        )

    @staticmethod
    def _to_timestamp(dt):
        """Convert a naive UTC ``datetime`` to integer Unix epoch seconds.

        ``calendar.timegm`` interprets the time tuple as UTC.  ``time.mktime``
        would interpret it as *local* time and give wrong values on hosts whose
        timezone is not UTC.
        """
        return calendar.timegm(dt.utctimetuple())

    def on_execute(self):
        """Import yesterday's metrics interval by interval, resuming after the last stored one."""
        # These packages are installed at runtime by Requirements.environments,
        # hence the function-scope imports.
        import yt.wrapper as yt
        from sandbox.projects.cloud.analytics.common.solomon.client import SolomonAPI, SolomonDataAPI
        from concurrent.futures import as_completed

        def get_compute_data(project, cluster, service, instance_id, aggregate, begin, end, solomon_data_api):
            """Yield the single compute-metrics request for the '|'-joined ``instance_id`` selector."""
            extra_params = {
                'l.metric': self.SOLOMON_COMPUTE_METRIC,
                'l.instance_id': instance_id,
                'b': begin,
                'e': end,
                'points': 1,
                'downsamplingAggr': aggregate,
            }
            yield solomon_data_api.get(project, cluster, service, extra_params)

        def get_disk_data(project, cluster, service, instance_id, aggregate, begin, end, solomon_api, solomon_data_api):
            """Yield one disk-metrics request per distinct (instance_id, device, host) sensor set.

            The sensor list is paged through with ``solomon_api.sensors``; every
            yielded request then asks for all ``SOLOMON_DISK_METRIC`` metrics of
            one sensor set.
            """
            selectors = {
                'cluster': cluster,
                'service': service,
                'instance_id': instance_id,
                'metric': self.SOLOMON_DISK_METRIC,
                'project_id': '-',
                'bucket': '-',
                'device': '*',
            }
            sensors_params = {
                'selectors': ', '.join(["{}={}".format(k, v) for k, v in selectors.items()]),
                'forceCluster': 'vla',
                'pageSize': self.Parameters.page_size,
            }

            sensors = {}
            page, pages_count = 0, 1
            while page < pages_count:
                sensors_params['page'] = page
                info = solomon_api.sensors(self.SOLOMON_PROJECT, sensors_params)
                pages_count = info['page']['pagesCount']
                page += 1
                for result in info['result']:
                    labels = result['labels']
                    # Drop labels that the data request supplies itself; the rest
                    # are passed back as "l.<name>" filters below.
                    for k in ('cluster', 'service', 'project', 'metric'):
                        del labels[k]
                    # this triple is our primary key; it also de-duplicates the
                    # per-metric sensors of the same disk
                    key = "{instance_id}_{device}_{host}".format(**labels)
                    sensors[key] = labels

            # request only disk metrics that are not histograms (bucket '-')
            for sensor in sensors.values():
                extra_params = {
                    'b': begin,
                    'e': end,
                    'points': 1,
                    'downsamplingAggr': aggregate,
                    'l.metric': self.SOLOMON_DISK_METRIC,
                    'l.bucket': '-',
                }
                for k, v in sensor.items():
                    extra_params['l.{}'.format(k)] = v
                yield solomon_data_api.get(project, cluster, service, extra_params)

        def get_solomon_data(project, cluster, service, instance_ids, aggregate, from_ts, to_ts, solomon_api, solomon_data_api):
            """Yield, for each batch of ``instance_ids``, generators of compute and disk requests."""
            begin = from_ts.strftime(SolomonDataAPI.TIME_FORMAT)
            end = to_ts.strftime(SolomonDataAPI.TIME_FORMAT)
            for batch in grouper_it(self.Parameters.batch_size, instance_ids):
                # one '|'-joined selector covers the whole batch in a single request
                instance_id = '|'.join(batch)
                yield get_compute_data(project, cluster, service, instance_id, aggregate, begin, end, solomon_data_api)
                yield get_disk_data(project, cluster, service, instance_id, aggregate, begin, end, solomon_api, solomon_data_api)

        def parse_solomon(futures, aggregate, from_ts, to_ts):
            """Resolve request futures in batches and yield one YT row per non-empty sensor.

            :param futures: iterable of iterables of request futures (as produced by get_solomon_data)
            :param aggregate: name of the value column to fill ('min', 'max' or 'avg')
            """
            # key columns are stored negated so that the newest interval sorts first
            start = -self._to_timestamp(from_ts)
            end = -self._to_timestamp(to_ts)
            for future_batch in grouper_it(self.Parameters.futures_size, chain.from_iterable(futures)):
                for future in as_completed(list(future_batch)):
                    data = future.result().data
                    for sensor in data['sensors']:
                        values = sensor['values']
                        if not values:
                            continue
                        labels = sensor['labels']
                        yield {
                            'selectors': ','.join(["{}={}".format(k, v) for k, v in labels.items() if k not in ['instance_id', 'metric']]),
                            'resource_id': labels['instance_id'],
                            'metric': labels['metric'],
                            'start': start,
                            'end': end,
                            aggregate: values[0]['value'],
                        }

        yt.config['token'] = sdk2.Vault.data(self.owner, "robot-clanalytics-yt-yt-token")
        yt.config['proxy']['url'] = 'hahn'

        solomon_api = SolomonAPI(sdk2.Vault.data(self.owner, "robot-clanalytics-yt-solomon-token"))
        solomon_data_api = SolomonDataAPI('robot-clanalytics-yt', use_futures=True)

        # get instance identifiers from YT (prepare temporary table with instance ids)
        tmp_table = "{}_instance".format(self.YT_PREFIX)
        self._wait_for_subtask(
            "run_yql",
            "RUN_YQL_2",
            query=self.QUERY,
            trace_query=True,
            owner=self.Parameters.owner,
            publish_query=True,
            use_v1_syntax=True,
            custom_placeholders={
                '%TO_TS%': self._to_timestamp(self.DEFAULT_TO_TS),
                '%FROM_TS%': self._to_timestamp(self.DEFAULT_FROM_TS),
                '%TMP_TABLE%': tmp_table
            }
        )
        logging.info("Created temporary table {} for instance id".format(tmp_table))
        # The table does not change during the run: read it once instead of
        # once per interval and aggregate.
        instance_ids = [row['resource_id'] for row in yt.read_table(yt.TablePath(tmp_table, columns=["resource_id"]))]

        if not yt.exists(os.path.dirname(self.YT_PREFIX)):
            yt.create('map_node', os.path.dirname(self.YT_PREFIX), recursive=True)
        if not yt.exists(self.YT_PREFIX):
            yt.create('table', self.YT_PREFIX, attributes={
                "dynamic": True,
                "schema": [
                    {"name": "start", "type": "int64", "sort_order": "ascending"},
                    {"name": "end", "type": "int64", "sort_order": "ascending"},
                    {"name": "resource_id", "type": "string", "sort_order": "ascending"},
                    {"name": "metric", "type": "string", "sort_order": "ascending"},
                    {"name": "selectors", "type": "string", "sort_order": "ascending"},
                    {"name": "min", "type": "double"},
                    {"name": "max", "type": "double"},
                    {"name": "avg", "type": "double"},
                ],
            })
        yt.mount_table(self.YT_PREFIX, sync=True)

        from_ts = self.DEFAULT_FROM_TS
        # Get the last stored interval; `start` is negated, so the smallest key is
        # the newest one. ORDER BY is required: without it YT does not guarantee
        # row order across tablets.
        last_start_result = list(yt.select_rows("start FROM [{}] ORDER BY start LIMIT 1".format(self.YT_PREFIX)))
        if last_start_result:
            # Stored start is interval begin + 1s; step back to re-import that
            # interval, assuming the previous run may not have finished it.
            last_start_ts = datetime.utcfromtimestamp(-int(last_start_result[0]['start'])) - timedelta(seconds=1)
            # don't continue previous day
            from_ts = max(from_ts, last_start_ts)
            logging.info("Got last timestamp {} and continue from {}".format(last_start_ts, from_ts))

        to_ts = from_ts + timedelta(seconds=self.INTERVAL)
        while to_ts <= self.DEFAULT_TO_TS:
            # intervals are requested as (from_ts, to_ts], hence the one-second shift
            begin_ts = from_ts + timedelta(seconds=1)
            for aggregate in self.AGGREGATES:
                logging.info('Importing {} from {} to {}'.format(aggregate, from_ts, to_ts))

                data = get_solomon_data(self.SOLOMON_PROJECT,
                                        self.SOLOMON_CLUSTER,
                                        self.SOLOMON_SERVICE,
                                        instance_ids,
                                        aggregate,
                                        begin_ts,
                                        to_ts,
                                        solomon_api,
                                        solomon_data_api)
                rows = parse_solomon(data, aggregate, begin_ts, to_ts)
                for batch in grouper_it(self.Parameters.rows_size, rows):
                    yt.insert_rows(self.YT_PREFIX, batch, update=True)

            from_ts, to_ts = to_ts, to_ts + timedelta(seconds=self.INTERVAL)
        yt.unmount_table(self.YT_PREFIX)
        yt.remove(tmp_table)
