import yaml
import datetime
import json
import os
from sandbox import sdk2
from sandbox.common.types.resource import State

from sandbox.projects import resource_types
from sandbox.projects.ads.online_learning.resources import OnlineOfflineMetricsBinary, OnlineOfflineMetricsState
import logging
import shutil


# strptime/strftime pattern shared by all timestamps in this module
# (year, month, day, hour, minute — no separators).
TIME_FORMAT = "%Y%m%d%H%M"

# Upper bound on how many feature-map resources a single lookup fetches.
LIMIT_FOR_FEATURE_MAPS = 10

# Comma-separated metric names used when the task's "Metrics to compute"
# parameter is left empty.
DEFAULT_METRICS = (
    "dump_size,"
    "number_of_features,"
    "llp,"
    "ctr_factor,"
    "unweighted_ctr_factor,"
    "clicks,"
    "shows,"
    "unweighted_clicks,"
    "unweighted_shows"
)


def convert_resource_2_path(resource):
    """Return the path of ``resource``'s data as a plain string.

    The path is taken from ``sdk2.ResourceData(resource).path``.
    """
    resource_data = sdk2.ResourceData(resource)
    return str(resource_data.path)


def get_latest_state_path(ml_task_id):
    """Return the data path of the newest READY state resource for a task.

    Looks up ``OnlineOfflineMetricsState`` resources whose ``ml_task_id``
    attribute equals ``ml_task_id``, ordered by descending resource id, and
    takes the first one.

    Returns the path as a string, or ``None`` when no such resource exists.
    """
    query = OnlineOfflineMetricsState.find(
        attrs={"ml_task_id": ml_task_id},
        state=(State.READY,),
    )
    latest = query.order(-sdk2.Resource.id).first()
    if not latest:
        return None

    logging.info('ResourceID = {resource_id}'.format(resource_id=str(latest)))
    return str(sdk2.ResourceData(latest).path)


def get_dump_resources(start_time, end_time, task_id, task_id_field, feature_map_type):
    """Return READY feature-map resources whose log time is in the window.

    At most ``LIMIT_FOR_FEATURE_MAPS`` resources of ``feature_map_type`` with
    attribute ``task_id_field == task_id`` are fetched, newest id first; the
    time filter is applied only to those fetched resources.

    ``start_time`` and ``end_time`` are strings and are compared
    lexicographically with ``str(resource.log_time)``; the window is
    half-open: ``start_time <= log_time < end_time``.
    """
    found = sdk2.Resource.find(
        resource_type=feature_map_type,
        state=State.READY,
        attrs={task_id_field: task_id}
    ).order(-sdk2.Resource.id).limit(LIMIT_FOR_FEATURE_MAPS)

    # A falsy query result is treated as "nothing found".
    candidates = found if found else []

    in_window = []
    for candidate in candidates:
        if start_time <= str(candidate.log_time) < end_time:
            in_window.append(candidate)
    return in_window


def fetch_released_binary(stable):
    """Return the data path of the newest released metrics binary.

    ``stable`` selects the release status to look for: ``'stable'`` when
    truthy, ``'testing'`` otherwise.  The newest (highest id)
    ``OnlineOfflineMetricsBinary`` resource with that ``released`` attribute
    is used.

    Raises:
        RuntimeError: if no resource with the requested release status is
            found.  Without this check the missing resource would be handed
            to ``sdk2.ResourceData`` and fail with an unrelated-looking error.
    """
    release_status = 'stable' if stable else 'testing'
    logging.info("Release status: {release_status}".format(release_status=release_status))
    resource = sdk2.Resource.find(
        resource_type=OnlineOfflineMetricsBinary,
        attrs={"released": release_status}
    ).order(-sdk2.Resource.id).first()
    if not resource:
        raise RuntimeError(
            "No OnlineOfflineMetricsBinary resource released as "
            "'{release_status}' was found".format(release_status=release_status)
        )
    logging.info('ResourceID = {resource_id}'.format(resource_id=str(resource)))
    resource_data = sdk2.ResourceData(resource)
    return str(resource_data.path)


class ComputeOfflineMetricsOnlineLearning(sdk2.Task):
    """Compute simple offline metrics for online-learning feature-map dumps.

    The task selects the feature-map dumps whose ``log_time`` falls into the
    requested time window, writes a YAML config describing them and runs the
    released metrics binary with that config.  After the binary exits, the
    state file is published as an ``OnlineOfflineMetricsState`` resource.
    """

    class Requirements(sdk2.Requirements):
        cores = 1

        # NOTE(review): empty Caches override — presumably declares that the
        # task needs no shared caches; confirm against the Sandbox SDK.
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        # Comma-separated metric names; DEFAULT_METRICS is used when empty.
        metrics_to_compute = sdk2.parameters.String("Metrics to compute", required=False)
        ml_task_id = sdk2.parameters.String("Task ID", required=True)
        # Explicit state to resume from; when unset, the newest READY state
        # resource for ``ml_task_id`` is looked up instead.
        # NOTE(review): this resource type differs from the
        # OnlineOfflineMetricsState type used for the lookup and for
        # publishing — confirm they describe the same file format.
        compute_metrics_state = sdk2.parameters.Resource(
            'Optional process state',
            resource_type=resource_types.OFFLINE_METRICS_ONLINE_LEARNING_STATE,
            required=False,
        )

        # NOTE(review): declared optional, but on_execute resolves its path
        # unconditionally — confirm whether it should be required.
        task_yml = sdk2.parameters.Resource(
            'task.yml',
            resource_type=resource_types.OTHER_RESOURCE,
            required=False,
        )

        yt_token_owner = sdk2.parameters.String(
            "YT token owner",
            required=True,
            default="ML-ENGINE"
        )

        yt_token_name = sdk2.parameters.String(
            "YT token name",
            required=True,
            default="robot_ml_engine_hahn_yt_token"
        )

        start_time = sdk2.parameters.String(
            "First last_log_date of task, if no state found, format=YmdHM",
            required=False,
        )

        end_time = sdk2.parameters.String(
            "Last last_log_date to compute, default is now, format=YmdHM",
            required=False,
        )

        # NOTE(review): a String parameter with an integer default, and the
        # keyword is spelled ``default_value`` here but ``default`` on the
        # token parameters above — confirm both spellings are honoured.
        hour_delay = sdk2.parameters.String(
            "Delay (apply dump with time on [time+delay, time+delay+1) logs)",
            required=True,
            default_value=1
        )

        stable = sdk2.parameters.Bool(
            "Use stable binary",
            required=False,
            default_value=True
        )

    def on_execute(self):
        """Run the metrics binary over the dumps in the requested window.

        Steps: resolve the binary and ``task.yml`` paths, determine the time
        window (resuming from a stored state when one is available), collect
        the matching dumps, write ``compute_offline_metrics.conf``, run the
        binary and publish the state file.
        """
        binary = fetch_released_binary(self.Parameters.stable)

        task_file = convert_resource_2_path(self.Parameters.task_yml)
        ml_task_id = self.Parameters.ml_task_id

        end_time, start_time = self.parse_dates(self.Parameters)
        state_file = "compute_offline_metrics_online_learning.state.json"

        # A stored state overrides the start_time parameter.
        resume_from = self._restore_state(state_file, ml_task_id)
        if resume_from is not None:
            start_time = resume_from

        window_start = start_time.strftime(TIME_FORMAT)
        window_end = end_time.strftime(TIME_FORMAT)
        logging.info('[start_time: %s, end_time: %s]' % (window_start, window_end))
        dump_resources = get_dump_resources(
            window_start,
            window_end,
            ml_task_id,
            task_id_field='ml_task_id',
            feature_map_type=resource_types.ONLINE_LEARNING_FEATURE_MAP
        )

        dumps_with_info = [
            {
                'dump_file_name': convert_resource_2_path(resource),
                'last_log_date': str(resource.log_time)
            }
            for resource in dump_resources
        ]

        logging.info("Dumps found: %s" % (str(dumps_with_info)))
        metrics_to_compute = self.Parameters.metrics_to_compute
        if not metrics_to_compute:
            metrics_to_compute = DEFAULT_METRICS
        logging.info(metrics_to_compute)
        context = {
            "compute_offline_metrics_online_learning": {
                "state": state_file,
                "time_format": TIME_FORMAT,
                "task_id": ml_task_id,
                "metrics_to_compute": metrics_to_compute,
                "dumps_with_info": dumps_with_info,
                "hour_delay": self.Parameters.hour_delay,
                "graphite": {
                    "graphite_prefix": "one_hour.online_learning.offline_metrics",
                    "attempts": 6,
                    "delay": 600
                }
            }
        }

        with open("compute_offline_metrics.conf", "w") as out:
            yaml.dump(context, out)
        logging.info(context)
        logging.info(binary)
        logging.info(os.listdir('.'))

        # The environment carries the YT token; it is deliberately not logged.
        env = self.prepare_env()

        cmd = [
            binary,
            "--conf", "compute_offline_metrics.conf",
            "--task", task_file
        ]
        logging.info(cmd)
        with sdk2.helpers.ProcessLog(self, logger='compute_metrics') as pl:
            sdk2.helpers.subprocess.check_call(
                cmd,
                env=env,
                stdout=pl.stdout,
                stderr=pl.stderr
            )

        self.dump_state(state_file, ml_task_id)

    def _restore_state(self, state_file, ml_task_id):
        """Copy a stored state to ``state_file``; return the resume time.

        The explicitly supplied ``compute_metrics_state`` resource is used
        when set (it used to be ignored entirely); otherwise the newest READY
        state resource for ``ml_task_id`` is looked up.

        Returns the stored ``last_time`` plus one hour as a ``datetime``, or
        ``None`` when there is no state or its ``last_time`` is ``"0"``.
        """
        explicit_state = self.Parameters.compute_metrics_state
        if explicit_state:
            resource_path = convert_resource_2_path(explicit_state)
        else:
            resource_path = get_latest_state_path(ml_task_id)
        if not resource_path:
            return None

        # Work on a local copy: the same file name is later handed to the
        # binary via the config and published as a new state resource.
        shutil.copyfile(resource_path, state_file)
        with open(state_file, 'r') as state:
            process_state = json.load(state)

        last_time = process_state["last_time"]
        if last_time == "0":
            return None
        return datetime.datetime.strptime(last_time, TIME_FORMAT) + datetime.timedelta(hours=1)

    def dump_state(self, state_file, ml_task_id):
        """Publish ``state_file`` as an ``OnlineOfflineMetricsState`` resource.

        The resource is created for this task with the ``ml_task_id``
        attribute set and is then marked ready.
        """
        state = sdk2.ResourceData(
            OnlineOfflineMetricsState(
                self,
                "State for computing offline metrics, ml_task_id = %s" % ml_task_id,
                path=state_file,
                ml_task_id=ml_task_id
            )
        )
        state.ready()

    def prepare_env(self):
        """Return a copy of ``os.environ`` with ``YT_TOKEN`` set.

        The token is read from the Vault using the ``yt_token_owner`` and
        ``yt_token_name`` parameters.
        """
        yt_token = sdk2.Vault.data(self.Parameters.yt_token_owner, self.Parameters.yt_token_name)
        env = os.environ.copy()
        env['YT_TOKEN'] = yt_token
        return env

    @staticmethod
    def _parse_time(raw_value, param_name):
        """Parse ``raw_value`` with ``TIME_FORMAT``; return ``None`` on failure.

        Both ``ValueError`` (empty or malformed string) and ``TypeError``
        (non-string value such as ``None``) count as "not provided".  A
        non-empty value that fails to parse is logged, so that a typo in the
        parameter does not silently widen the window.
        """
        try:
            return datetime.datetime.strptime(raw_value, TIME_FORMAT)
        except (TypeError, ValueError):
            if raw_value:
                logging.warning(
                    "Cannot parse %s=%r with format %s, using the default",
                    param_name, raw_value, TIME_FORMAT
                )
            return None

    def parse_dates(self, params):
        """Return ``(end_time, start_time)`` parsed from ``params``.

        Note the order of the returned pair: end first, start second.

        A missing or unparsable ``start_time`` falls back to 1901-01-01, a
        missing or unparsable ``end_time`` falls back to the current local
        time.
        """
        start_time = self._parse_time(params.start_time, "start_time")
        if start_time is None:
            # Sentinel far enough in the past to include every dump.
            start_time = datetime.datetime(1901, 1, 1)

        end_time = self._parse_time(params.end_time, "end_time")
        if end_time is None:
            end_time = datetime.datetime.now()

        return end_time, start_time
