# -*- coding: utf-8 -*-

import datetime
import logging
import os

from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.parameters import ResourceSelector, SandboxIntegerParameter, SandboxStringParameter
from sandbox.sandboxsdk.paths import add_write_permissions_for_path, copy_path
from sandbox.sandboxsdk.process import run_process
from sandbox.sandboxsdk.task import SandboxTask

from sandbox import common
from sandbox.projects import resource_types
from sandbox.projects.common import apihelpers
from sandbox.projects.common.utils import get_or_default


# Maps a source roads-graph resource type to the resource type that holds the
# same graph enriched with traffic-history stats.
# NOTE(review): not referenced by the task class in this file; confirm whether
# other modules still import it before removing.
GRAPH_RESOURCE_MAPPING = dict([
    (
        resource_types.RTYSERVER_GRAPH_DIR_PRODUCTION_RUS,
        resource_types.RTYSERVER_GRAPH_WITH_HISTORY_DIR_PRODUCTION_RUS,
    ),
])


class RoadsGraphToolsBinary(ResourceSelector):
    """Selector for the saas/tools/roads_graph binary resource.

    The default is the most recently released resource of ``resource_type``.
    If there is no such resource, the default is None rather than an error.
    """
    name = 'roads_graph'
    description = 'Binary from saas/tools/roads_graph'
    resource_type = resource_types.RTYSERVER_UTILS_ROADS_GRAPH

    @common.utils.classproperty
    def default_value(cls):
        # Resolve through ``cls`` so a subclass overriding ``resource_type``
        # gets its own default instead of this class's.
        resource = apihelpers.get_last_released_resource(cls.resource_type)
        # NOTE(review): assumes the helper returns a falsy value when nothing
        # has been released. The old code then raised AttributeError on `.id`.
        return resource.id if resource else None


class GraphResourceTypeParameter(SandboxStringParameter):
    """Resource type under which the graph-with-history result is published."""
    required = True
    description = 'Target graph resource type'
    name = 'graph_resource_type'


class GraphResourceIdParameter(SandboxIntegerParameter):
    """Id of the source roads-graph resource to enrich with traffic history."""
    required = True
    description = 'Source graph resource id'
    name = 'graph_resource_id'


class YtServerParameter(SandboxStringParameter):
    """YT proxy name used by the task's YT client and roads_graph calls."""
    default_value = 'freud'
    description = 'YT server'
    name = 'yt_server'


class YtTokenVaultNameParameter(SandboxStringParameter):
    """Vault key (owned by the task owner) that stores the YT token."""
    default_value = 'yt_freud_token'
    description = 'YT token vault key'
    name = 'yt_token_vault'


class YtPoolParameter(SandboxStringParameter):
    """YT pool exported to roads_graph subprocesses as ``YT_POOL``."""
    default_value = 'extdata'
    description = 'YT pool'
    name = 'yt_pool'


class YtWorkDirParameter(SandboxStringParameter):
    """YT directory holding the edge-speeds table, graph diff and road graphs."""
    default_value = '//home/extdata/rtline/traffic-history'
    required = True
    description = 'YT working directory'
    name = 'yt_work_dir'


class AnalyzerLogDirParameter(SandboxStringParameter):
    """YT directory passed to roads_graph as ``--yt-analyzer-log-dir``."""
    default_value = '//home/logfeller/logs/analyzer-dispatcher-signals-log/1d'
    description = 'Analyzer dispatcher log directory'
    name = 'analyzer_log_dir'


class ImportTrackThresholdParameter(SandboxIntegerParameter):
    """Value passed to ``roads_graph import_history --track-threshold``."""
    default_value = 0
    description = 'Discard edge stats based on a small number of tracks'
    # (sic) "treshold": this is the task-context key, so renaming it would
    # break existing task contexts.
    name = 'track_treshold'


class ImportSpeedThresholdParameter(SandboxIntegerParameter):
    """Value passed to ``roads_graph import_history --speed-threshold``."""
    default_value = 0
    description = 'Discard edge stats having unusually low speed'
    name = 'speed_threshold'


class MaxAgeParameter(SandboxIntegerParameter):
    """Passed as ``--filter-max-age`` to ``traffic_history`` only when > 0.

    NOTE(review): units are defined by the roads_graph tool and are not
    visible here; confirm before documenting them.
    """
    default_value = 0
    description = 'Max age for tracks'
    name = 'max_age'


class YtLoggingLevel(SandboxIntegerParameter):
    """Numeric level applied to the in-process ``Yt`` Python logger."""
    default_value = 20
    description = 'YT logging level: 10 - debug, 20 - info, etc., see: https://docs.python.org/3/library/logging.html#logging-levels'
    name = 'yt_logging_level'


class SaasRoadsGraphTrafficHistoryBuilder(SandboxTask):
    """Build a roads-graph resource enriched with per-edge traffic history.

    ``on_execute`` performs these steps:

    1. Sync the source graph and make a writable local copy of it.
    2. Find the most recent graph of the same type that was already processed.
       If it is newer than the source graph, abort. If it is older, compute a
       diff against it and upload the diff to YT.
    3. If the aggregated ``edge-speeds`` table is missing on YT, upload the
       graph and run ``roads_graph traffic_history`` to build it.
    4. Run ``roads_graph import_history`` to write ``edge_history`` into the
       local graph copy.
    5. Publish the copy as a new resource. Record its id on the source graph
       resource under GRAPH_WITH_HISTORY_ID_ATTRIBUTE.
    """

    type = 'SAAS_ROADS_GRAPH_TRAFFIC_HISTORY_BUILDER'

    input_parameters = [
        RoadsGraphToolsBinary,
        GraphResourceTypeParameter,
        GraphResourceIdParameter,
        YtWorkDirParameter,
        YtServerParameter,
        YtTokenVaultNameParameter,
        YtPoolParameter,
        AnalyzerLogDirParameter,
        ImportTrackThresholdParameter,
        ImportSpeedThresholdParameter,
        MaxAgeParameter,
        YtLoggingLevel,
    ]

    environment = (
        environments.PipEnvironment('yandex-yt', use_wheel=True),
    )

    TIMEOUT = 36 * 3600  # presumably seconds (36 hours) -- confirm
    RAM = 81920  # presumably MiB (80 GiB) -- confirm
    EXECUTION_SPACE = 64 * 1024  # presumably MiB (64 GiB) -- confirm

    # Attribute set on a source graph resource. Its value is the id of the
    # graph-with-history resource built from it. Its presence marks the graph
    # as already processed (see get_prev_graph).
    GRAPH_WITH_HISTORY_ID_ATTRIBUTE = 'graph_with_history_id'

    def on_execute(self):
        """Build and publish the graph-with-history resource for the input graph.

        :raises RuntimeError: if the resulting resource id cannot be recorded
            on the source graph resource.
        """
        # Imported lazily: the package is installed by the PipEnvironment above.
        import yt.wrapper as yt

        roads_graph_binary = self.sync_resource(get_or_default(self.ctx, RoadsGraphToolsBinary))

        yt_server = get_or_default(self.ctx, YtServerParameter)
        yt_token_vault_key = get_or_default(self.ctx, YtTokenVaultNameParameter)
        yt_token = self.get_vault_data(self.owner, yt_token_vault_key)
        yt_pool = get_or_default(self.ctx, YtPoolParameter)
        logging.getLogger("Yt").setLevel(get_or_default(self.ctx, YtLoggingLevel))

        yt_config = {
            "write_parallel": {
                "enable": True,
                "max_thread_count": 3
            }
        }
        yt_client = yt.client.Yt(proxy=yt_server, token=yt_token, config=yt_config)

        yt_work_dir = self.ctx[YtWorkDirParameter.name]
        yt_graph_dir = self.get_yt_graph_dir(yt_work_dir)
        yt_aggregated_speeds_table = os.path.join(yt_work_dir, 'edge-speeds')
        yt_analyzer_log_dir = get_or_default(self.ctx, AnalyzerLogDirParameter)

        graph_resource_id = self.ctx[GraphResourceIdParameter.name]
        graph_resource = channel.sandbox.get_resource(graph_resource_id)
        remote_graph_path = channel.task.sync_resource(graph_resource_id)
        # Work on a writable local copy: import_history adds edge_history to it
        # and the copy becomes the published resource.
        graph_path = self.path('graph_info')
        copy_path(remote_graph_path, graph_path)
        add_write_permissions_for_path(graph_path)
        hist_path = os.path.join(graph_path, 'edge_history')

        # Prepare the diff against the most recently processed graph, if any.
        yt_graph_diff_path = None
        prev_graph = self.get_prev_graph(graph_resource)
        if prev_graph:
            if prev_graph.id > graph_resource.id:
                # Don't allow to specify old graph revisions.
                logging.error(
                    'Current graph %s is older than the most recent graph with history %s. Aborting',
                    graph_resource.id,
                    prev_graph.id,
                )
                # NOTE(review): this ends the task without raising, so it does
                # not fail. Confirm that this is intended.
                return
            if prev_graph.id < graph_resource.id:
                # Calculate diff with a previous graph.
                prev_graph_path = self.sync_resource(prev_graph.id)
                yt_graph_diff_path = self.prepare_yt_graph_diff(roads_graph_binary,
                                                                yt_work_dir,
                                                                prev_graph_path,
                                                                graph_path,
                                                                yt_client)
            # Equal ids: this very graph was processed before, so no diff.

        # Optional parameters: fall back to declared defaults if absent from ctx.
        import_track_threshold = get_or_default(self.ctx, ImportTrackThresholdParameter)
        import_speed_threshold = get_or_default(self.ctx, ImportSpeedThresholdParameter)
        max_age = get_or_default(self.ctx, MaxAgeParameter)

        # The aggregated speeds table is rebuilt only when it is absent. An
        # existing table is reused as-is.
        if not yt_client.exists(yt_aggregated_speeds_table):
            self.upload_graph_to_yt(graph_path, yt_graph_dir, yt_client)
            self.update_index(
                roads_graph_binary, yt_server, yt_token, yt_pool, yt_work_dir,
                yt_graph_dir, yt_graph_diff_path, yt_analyzer_log_dir, yt_client, max_age
            )
        self.import_history(
            roads_graph_binary, yt_server, yt_token, yt_pool,
            yt_aggregated_speeds_table, import_track_threshold,
            import_speed_threshold, hist_path
        )

        graph_with_history = self.create_resource(
            description=graph_resource.description,
            resource_path=graph_path,
            resource_type=self.ctx[GraphResourceTypeParameter.name],
        )

        ok = channel.sandbox.set_resource_attribute(
            graph_resource_id,
            self.GRAPH_WITH_HISTORY_ID_ATTRIBUTE,
            str(graph_with_history.id)
        )
        # Explicit check instead of `assert`, which is stripped under -O.
        if not ok:
            raise RuntimeError(
                'Failed to set attribute {} = {} on resource {}'.format(
                    self.GRAPH_WITH_HISTORY_ID_ATTRIBUTE,
                    graph_with_history.id,
                    graph_resource_id,
                )
            )

    def get_prev_graph(self, cur_graph_resource):
        """Find the most recently processed graph resource to diff against.

        Scans up to 100 non-failed resources of the same type as
        ``cur_graph_resource``. Returns the first one that carries
        GRAPH_WITH_HISTORY_ID_ATTRIBUTE.

        :param cur_graph_resource: the source graph resource being processed.
        :return: the matching resource, or None if none is found.
        """
        resources = channel.sandbox.list_resources(
            resource_type=cur_graph_resource.type,
            omit_failed=True,
            limit=100,
        )
        # NOTE(review): "most recent" relies on list_resources returning the
        # newest resources first -- confirm. `or ()` guards a None result.
        for resource in resources or ():
            if self.GRAPH_WITH_HISTORY_ID_ATTRIBUTE in (resource.attributes or {}):
                return resource
        return None

    def prepare_yt_graph_diff(self, roads_graph_binary, yt_work_dir, prev_graph, cur_graph, yt_client):
        """Diff two local graphs with ``roads_graph diff`` and upload the result.

        :param roads_graph_binary: local path to the roads_graph tool.
        :param yt_work_dir: YT directory that receives the diff; created if missing.
        :param prev_graph: local ``graph_info`` directory of the previous graph.
        :param cur_graph: local ``graph_info`` directory of the current graph.
        :param yt_client: YT client used for the upload.
        :return: YT path of the uploaded diff, ``<yt_work_dir>/graph.diff``.
        """
        diff_path = self.path('graph.diff')
        cmd = [
            roads_graph_binary,
            'diff',
            os.path.dirname(prev_graph),  # graph_info parent is expected.
            os.path.dirname(cur_graph),
        ]
        # The tool's stdout is the diff itself.
        with open(diff_path, 'w') as f:
            logging.info("prepare_yt_graph_diff: run_process: {}".format(cmd))
            run_process(cmd, log_prefix='diff_graphs', stdout=f, outputs_to_one_file=False)

        if not yt_client.exists(yt_work_dir):
            logging.info("prepare_yt_graph_diff: yt_client.create: {}".format(yt_work_dir))
            yt_client.create('map_node', yt_work_dir, recursive=True)

        yt_graph_diff_path = os.path.join(yt_work_dir, 'graph.diff')
        # Binary mode: the upload should carry the raw bytes unchanged.
        with open(diff_path, 'rb') as f:
            logging.info("prepare_yt_graph_diff: yt_client.write_file: {}".format(yt_graph_diff_path))
            yt_client.write_file(yt_graph_diff_path, f, force_create=True)

        return yt_graph_diff_path

    def get_yt_graph_dir(self, yt_work_dir):
        """Return ``<yt_work_dir>/road_graphs/<UTC timestamp %Y%m%d%H%M%S>``."""
        timestamp = datetime.datetime.utcnow().strftime('%Y%m%d%H%M%S')
        yt_graph_dir = os.path.join(yt_work_dir, 'road_graphs', timestamp)
        return yt_graph_dir

    def upload_graph_to_yt(self, local_graph_info_dir, yt_graph_parent_dir, yt_client):
        """Upload every entry of a local ``graph_info`` dir to YT.

        The files go to ``<yt_graph_parent_dir>/graph_info/``. The parent
        directory gets an expiration time one week from now.

        :param local_graph_info_dir: local directory whose entries are uploaded.
        :param yt_graph_parent_dir: YT directory that will contain ``graph_info``.
        :param yt_client: YT client used for the upload.
        """
        yt_graph_info_dir = os.path.join(yt_graph_parent_dir, 'graph_info')
        if not yt_client.exists(yt_graph_info_dir):
            yt_client.create('map_node', yt_graph_info_dir, recursive=True)

        # Automatically remove old graphs in a week.
        expiration_time = datetime.datetime.utcnow() + datetime.timedelta(days=7)
        yt_client.set_attribute(yt_graph_parent_dir, 'expiration_time', expiration_time.isoformat())

        for filename in os.listdir(local_graph_info_dir):
            local_path = os.path.join(local_graph_info_dir, filename)
            yt_path = os.path.join(yt_graph_info_dir, filename)
            logging.info('upload_graph_to_yt: {} -> {}'.format(local_path, yt_path))
            # Close each handle after upload (previously leaked). Binary mode
            # keeps graph files byte-exact.
            with open(local_path, 'rb') as stream:
                yt_client.write_file(yt_path, stream)

    @staticmethod
    def _yt_env(yt_pool, yt_token):
        """Return the environment for roads_graph subprocesses that talk to YT."""
        return {
            'YT_LOG_LEVEL': 'INFO',
            'YT_POOL': yt_pool,
            'YT_TOKEN': yt_token,
        }

    def update_index(self, roads_graph_binary, yt_server, yt_token, yt_pool,
                     yt_work_dir, yt_graph_dir, yt_graph_diff, yt_analyzer_log_dir, yt_client, max_age):
        """Run ``roads_graph traffic_history`` on YT.

        :param yt_graph_diff: YT path of the graph diff, or None to omit it.
        :param yt_client: unused; kept for signature compatibility.
        :param max_age: passed as ``--filter-max-age`` only when positive.
        """
        cmd = [
            roads_graph_binary,
            'traffic_history',
            '--yt-server', yt_server,
            '--yt-work-dir', yt_work_dir,
            '--yt-graph-dir', yt_graph_dir,
            '--yt-analyzer-log-dir', yt_analyzer_log_dir,
        ]
        if max_age and max_age > 0:
            cmd += ['--filter-max-age', str(max_age)]
        if yt_graph_diff:
            cmd += ['--yt-graph-diff', yt_graph_diff]
        # The token goes via the environment, so logging cmd does not leak it.
        logging.info("update_index: run_process: {}".format(cmd))
        run_process(cmd, environment=self._yt_env(yt_pool, yt_token), log_prefix='update_index')

    def import_history(self, roads_graph_binary, yt_server, yt_token, yt_pool,
                       yt_aggregated_speeds_table, track_threshold, speed_threshold, out_path):
        """Run ``roads_graph import_history`` to write edge stats to ``out_path``.

        :param yt_aggregated_speeds_table: YT table read by the tool.
        :param track_threshold: value for ``--track-threshold``.
        :param speed_threshold: value for ``--speed-threshold``.
        :param out_path: local output file path.
        """
        cmd = [
            roads_graph_binary,
            'import_history',
            '--track-threshold', str(track_threshold),
            '--speed-threshold', str(speed_threshold),
            '--yt-server', yt_server,
            '--yt-table', yt_aggregated_speeds_table,
            out_path,
        ]
        logging.info("import_history: run_process: {}".format(cmd))
        run_process(cmd, environment=self._yt_env(yt_pool, yt_token), log_prefix='import_history')


# Module-level task export. Presumably this is the name the legacy Sandbox
# loader looks up to find the task class; confirm before renaming.
__Task__ = SaasRoadsGraphTrafficHistoryBuilder
