from resources import MlFetchDumpBinary, MlEngineScheduleProgram
import sandbox.common.types.resource as ctr
from datetime import datetime, date
import csv
import yaml
import os
import json
import requests
import shutil

from sandbox import sdk2
import sandbox.sandboxsdk.process as sdk_process
from sandbox.common import rest
import logging


def list_task_dates_on_sandbox(binary, task_id):
    """List the dates for which a dump of ``task_id`` is available.

    Runs ``<binary> -d <task_id> list``, waits for it, and splits its stdout
    on whitespace.

    :param binary: path to the ml_fetch_dump executable
    :param task_id: task identifier passed to ``-d``
    :return: list of date strings printed by the binary
    :raises: whatever ``sdk_process.check_process_return_code`` raises for the
        finished process (presumably on a non-zero exit code)
    """
    command = '{binary} -d {task_id} list'.format(binary=binary, task_id=task_id)
    proc = sdk_process.run_process(
        command,
        shell=True,
        outputs_to_one_file=False,
        wait=False,
        log_prefix='ml_fetch_dump_list'
    )
    proc.wait()
    # stdout is read before the return code is checked
    with open(proc.stdout_path) as stdout_file:
        dates = stdout_file.read().strip().split()
    sdk_process.check_process_return_code(proc)
    return dates


def download_task_metrics(binary, task_id):
    """Fetch the dump of ``task_id`` and return its learn/test metrics.

    Runs ``<binary> -d <task_id> --dst <dir> fetch`` into a directory under the
    current working directory. Metrics are then read from the old dump layout
    (``resources/target_stats.t``, tab-separated). If that fails, they are read
    from the new layout (``target_stats.json``). The fetched directory is
    removed afterwards, including when fetching or parsing fails.

    :param binary: path to the ml_fetch_dump executable
    :param task_id: task identifier passed to ``-d``; every ``/`` is replaced by
        ``_`` to build the local directory name
    :return: ``{'learn': ..., 'test': ...}`` metrics taken from the dump
    :raises ValueError: if both the learn and test metrics are empty
    """

    def get_metrics_from_old_format_dump(path_to_metrics):
        # Tab-separated table: one row per metric, keyed by the '# metric' column.
        learn_metrics = {}
        test_metrics = {}
        with open(path_to_metrics) as metrics_file:
            for r in csv.DictReader(metrics_file, dialect=csv.excel_tab):
                # suppose that learn metrics are always presented
                learn_metrics[r['# metric']] = float(r['learn'])
                test_metrics[r['# metric']] = float(r['test']) if 'test' in r else None
        return {
            'learn': learn_metrics,
            'test': test_metrics,
        }

    def get_metrics_from_new_format_dump(path_to_metrics):
        # JSON list of per-pool entries; the first 'train' and first 'test' entries are used.
        with open(path_to_metrics) as metrics_file:
            new_task_metrics = json.load(metrics_file)
        return {
            'learn': [x for x in new_task_metrics if x['pool'] == 'train'][0]['total']['target_stats'],
            'test': [x for x in new_task_metrics if x['pool'] == 'test'][0]['total']['target_stats'],
        }

    task_dir = os.path.join(os.getcwd(), task_id.replace('/', '_'))
    try:
        sdk_process.run_process(
            '{binary} -d {task_id} --dst {dst} fetch'.format(
                binary=binary, task_id=task_id, dst=task_dir
            ),
            shell=True,
            outputs_to_one_file=False,
            wait=True,
            log_prefix='ml_fetch_dump_fetch'
        )
        logging.info("Current directory: %s" % os.getcwd())
        logging.info("List dir in current directory: %s" % os.listdir(os.getcwd()))
        try:
            metrics = get_metrics_from_old_format_dump(os.path.join(task_dir, 'resources', 'target_stats.t'))
        except Exception as e:
            # Not an old-format dump (or unreadable): fall back to the new JSON layout.
            logging.error('Error happened while trying to get metrics from old dump format: %s' % e)
            metrics = get_metrics_from_new_format_dump(os.path.join(task_dir, 'target_stats.json'))
        logging.info("metrics in common format: %s" % json.dumps(metrics, indent=4))
        # The result dict always has both keys, so check their contents instead of the dict itself.
        if not metrics.get('learn') and not metrics.get('test'):
            raise ValueError("Can not get metrics for task %s" % task_id)
    finally:
        # Remove the fetched dump on every path; ignore_errors keeps a missing
        # directory (e.g. failed fetch) from masking the original exception.
        shutil.rmtree(task_dir, ignore_errors=True)

    return metrics


def are_tasks_ready(binary, task_ids_with_dates):
    """Return True if every task has a dump for its requested date.

    :param binary: path to the ml_fetch_dump executable
    :param task_ids_with_dates: iterable of ``'task_id/last_log_date'`` strings
    :return: True when, for each entry, ``last_log_date`` is among the dates
        listed for ``task_id`` (True for an empty iterable); False otherwise
    :raises ValueError: if an examined entry does not split into exactly two
        parts on ``/``
    """
    for task_info in task_ids_with_dates:
        task_id, last_log_date = task_info.split('/')
        # Stop at the first unready task instead of spawning a `list`
        # process for every remaining one; the result is the same.
        if last_log_date not in list_task_dates_on_sandbox(binary, task_id):
            logging.info('Task %s is not ready for date %s', task_id, last_log_date)
            return False
    return True


def kill_scheduler_from_task(task):
    """Delete the Sandbox scheduler that launched ``task``, if there is one.

    The scheduler id is not exposed through ``rest.Client``. Instead, the task's
    ``url`` (from the REST task info) is fetched with ``requests``, and the id
    is read from the ``scheduler.id`` field of the JSON response.

    The function logs and returns without raising when the response is not
    200, when no scheduler id is present, or when the scheduler API raises
    ``HTTPError``.

    :param task: task object; only ``task.id`` is used
    """
    sandbox_client = rest.Client()
    task_url = sandbox_client.task[task.id].read()['url']
    response = requests.get(task_url)
    if response.status_code != 200:
        logging.warning("Request to %s returned with status code %s" % (task_url, response.status_code))
        return
    # 'scheduler' may be missing or null (e.g. the task was started manually);
    # `or {}` avoids calling .get on None.
    scheduler_id = (json.loads(response.text).get('scheduler') or {}).get('id')
    if scheduler_id is None:
        logging.warning('No scheduler id found for task %s, nothing to stop' % task.id)
        return
    logging.info('Here we are going to stop scheduler with id %s' % scheduler_id)
    try:
        logging.info('Entering to code that will try to delete scheduler')
        # Reading the scheduler first serves as an existence check: the read is
        # expected to raise HTTPError for a scheduler that does not exist.
        scheduler_info = sandbox_client.scheduler[scheduler_id].read()
        logging.info(scheduler_info)
        logging.info('Deleting scheduler %s' % scheduler_id)
        del sandbox_client.scheduler[scheduler_id]
    except sandbox_client.HTTPError:
        logging.warning('Seems like you are trying to delete scheduler that does not exist')


def order_lms_by_metrics(binary, tasks, choose_best_lm_strategy):
    """Return ``tasks`` sorted by the score from ``choose_best_lm_strategy``, highest first.

    :param binary: path to the ml_fetch_dump executable
    :param tasks: iterable of task identifiers accepted by ``download_task_metrics``
    :param choose_best_lm_strategy: Python source of a one-argument callable
        (e.g. ``'lambda x: x["test"]["ll_p"]'``) that maps the metrics returned
        by ``download_task_metrics`` to a sortable score
    :return: list of task identifiers, best first; a task listed twice in
        ``tasks`` appears once
    """
    # NOTE(review): eval() runs arbitrary code taken from a task parameter;
    # acceptable only while that parameter is set by trusted users.
    score = eval(choose_best_lm_strategy)
    metrics_by_task = {t: download_task_metrics(binary, t) for t in tasks}

    # items() instead of the Python 2-only iteritems() so this works on 2 and 3.
    ranked = sorted(metrics_by_task.items(), key=lambda item: score(item[1]), reverse=True)
    return [task for task, _ in ranked]


def run_mx_task(schedule_binary, mx_last_log_date, mx_task_file_name, max_attempt_time, oauth_token):
    """Run the ``schedule`` binary once for the mx task file and wait for it.

    The run is one-shot with ``--from-time ts --to-time ts + 1``, where ``ts``
    is ``mx_last_log_date`` in seconds since 1970-01-01. ``--max-attempt-time``
    is the number of whole days from ``mx_last_log_date`` to the local date of
    ``max_attempt_time``.

    :param schedule_binary: path to the schedule executable
    :param mx_last_log_date: date of the mx task's last log (presumably a
        ``datetime.date`` parsed from YAML — TODO confirm)
    :param mx_task_file_name: path of the YAML file describing the mx task
    :param max_attempt_time: deadline as a Unix timestamp
    :param oauth_token: token passed through ``--token``
    :raises: whatever ``sdk_process.check_process_return_code`` raises for the
        finished process
    """
    # The flags are joined with single spaces into one str.format template.
    template = ' '.join([
        '{binary}',
        '--is-one-shot',
        '--from-time {from_ts}',
        '--to-time {to_ts}',
        '--max-attempt-time {attempt_days}',
        '--yt-pool ml-engine',
        '--period 5',  # fictive parameter to save common interface for ./schedule
        '{task_file}',
        '--static',
        '--combine-tasks',
        '--prod-dump',
        '--owner ML-ENGINE',
        '--sandbox-vault-name robot-ml-engine_ml_engine_app_token',
        '--nirvana-secret robot-ml-engine_ml_engine_app_token',
        '--token {token}',
    ])
    from_ts = int((mx_last_log_date - date(1970, 1, 1)).total_seconds())
    # NOTE(review): the token is placed on a shell command line and may show up
    # in process listings or run_process logs.
    command = template.format(
        binary=schedule_binary,
        from_ts=from_ts,
        to_ts=from_ts + 1,
        attempt_days=(date.fromtimestamp(max_attempt_time) - mx_last_log_date).days,
        task_file=mx_task_file_name,
        token=oauth_token
    )
    proc = sdk_process.run_process(
        command,
        shell=True,
        outputs_to_one_file=False,
        wait=False,
        log_prefix='schedule_run'
    )
    proc.wait()
    sdk_process.check_process_return_code(proc)


class AutoLmEvalWatcher(sdk2.Task):
    """Wait for linear-model (LM) tasks, then launch an mx task with the best of them.

    On each execution the task does the following:
      * if ``max_attempt_time`` has passed, it deletes the scheduler running it;
      * otherwise, if every ``task_id/date`` in ``tasks_to_wait`` has a dump,
        it ranks the LMs with ``choose_best_lm_strategy``;
      * it appends the best ``how_many_best_lms_try`` of them as models to
        ``mx_task_text``, launches the result via the schedule binary, and
        deletes the scheduler.
    """
    class Parameters(sdk2.Task.Parameters):
        tasks_to_wait = sdk2.parameters.String(
            "Linear models to wait, space separated, format task_id/task_log_date task_id/last_log_date",
            required=True
        )
        max_attempt_time = sdk2.parameters.Integer(
            "Till with time to work - timestamp",
            required=True
        )
        # The mx last_log_date is not a parameter: it is read from mx_task_text.
        choose_best_lm_strategy = sdk2.parameters.String(
            "Strategy to choose best lm",
            required=False,
            default='lambda x: x["test"]["ll_p"]'
        )
        mx_task_text = sdk2.parameters.String(
            "Mx task text without task.data.models info",
            required=True,
            multiline=True
        )
        free_mx_slots = sdk2.parameters.String(
            "Mx free slots, ints separated by space",
            required=True
        )
        ml_fetch_dump_binary = sdk2.parameters.Resource(
            'Program fetching ready tasks',
            resource_type=MlFetchDumpBinary,
            state=(ctr.State.READY,),
            required=False,
        )
        ml_fetch_dump_search_config = sdk2.parameters.String(
            '''Json with parameters of searching ml_fetch_dump program''',
            multiline=True
        )
        ml_engine_schedule_binary = sdk2.parameters.Resource(
            'Program running ml_engine_tasks with schedule',
            resource_type=MlEngineScheduleProgram,
            state=(ctr.State.READY,),
            required=False,
        )
        ml_engine_schedule_search_config = sdk2.parameters.String(
            '''Json with parameters of searching schedule program''',
            multiline=True
        )
        how_many_best_lms_try = sdk2.parameters.Integer(
            "How many best lms try in mx",
            required=False,
        )

    @staticmethod
    def _find_latest_resource(search_config):
        """Return the newest resource (highest id) matching a JSON dict of ``sdk2.Resource.find`` kwargs."""
        params = json.loads(search_config)
        logging.info('params: %s', params)
        return sdk2.Resource.find(**params).order(-sdk2.Resource.id).first()

    def on_save(self):
        """Resolve the binary resources from their JSON search configs, when given.

        A non-empty ``*_search_config`` overwrites the matching resource
        parameter with the newest resource found by that config.
        """
        logging.info('on save called')
        if self.Parameters.ml_fetch_dump_search_config:
            self.Parameters.ml_fetch_dump_binary = self._find_latest_resource(
                self.Parameters.ml_fetch_dump_search_config)
        else:
            logging.info('no config found: %s', self.Parameters.__getstate__())
        if self.Parameters.ml_engine_schedule_search_config:
            self.Parameters.ml_engine_schedule_binary = self._find_latest_resource(
                self.Parameters.ml_engine_schedule_search_config)
        else:
            logging.info('no config found: %s', self.Parameters.__getstate__())

    def on_execute(self):
        """Stop on timeout, wait for the LM tasks, then launch the mx task with the best LMs.

        :raises ValueError: if ``mx_task_text`` has no ``last_log_date`` or
            ``free_mx_slots`` is empty
        """
        max_attempt_time = self.Parameters.max_attempt_time

        # Check the deadline before readiness: if the awaited tasks never become
        # ready, the scheduler must still be stopped once time is out.
        if datetime.now() > datetime.fromtimestamp(max_attempt_time):
            logging.info('max_attempt_time %s has passed, stopping the scheduler', max_attempt_time)
            kill_scheduler_from_task(self)
            return

        ml_fetch_dump_binary = str(sdk2.ResourceData(self.Parameters.ml_fetch_dump_binary).path)
        schedule_binary = str(sdk2.ResourceData(self.Parameters.ml_engine_schedule_binary).path)
        task_ids_to_wait = self.Parameters.tasks_to_wait.split()

        if not are_tasks_ready(ml_fetch_dump_binary, task_ids_to_wait):
            logging.info('Not all tasks are ready yet: %s', task_ids_to_wait)
            return  # just do nothing; presumably the scheduler runs this task again later

        # Validate the cheap-to-parse inputs before downloading any metrics.
        # mx task should be inflated but not expanded
        # this means all includes should be done but generate section should not be called
        # safe_load_all: the text is a user parameter, and yaml.load without an
        # explicit Loader is unsafe (and rejected by PyYAML >= 6).
        mx_task_parts = list(yaml.safe_load_all(self.Parameters.mx_task_text))
        mx_last_log_date = None
        for part in mx_task_parts:
            # Later documents override earlier ones, but documents without the
            # key (or empty ones, parsed as None) do not reset it.
            if isinstance(part, dict) and part.get('last_log_date') is not None:
                mx_last_log_date = part['last_log_date']
        if mx_last_log_date is None:
            raise ValueError('Can not find value for mx task last_log_date')
        # mx_task['data']['models'] = {}   --- this functionality should be realised on the side of caller of this code

        free_mx_slots_ids = [int(x) for x in self.Parameters.free_mx_slots.split()]
        if not free_mx_slots_ids:
            raise ValueError('free_mx_slots must contain at least one slot id')

        # if this param is not passed then use only one best lm
        how_many_best_lms_try = self.Parameters.how_many_best_lms_try or 1
        ordered_tasks_ids = order_lms_by_metrics(
            ml_fetch_dump_binary, task_ids_to_wait, self.Parameters.choose_best_lm_strategy)
        best_lms = ordered_tasks_ids[:how_many_best_lms_try]

        models = {'baseline': {'tasks': []}}
        for best_lm in best_lms:
            # str here is because yaml loads strings as unicode
            task_id = str(best_lm.split('/')[0])
            # Each LM becomes its own model using the first free slot, presumably
            # so every LM is compared with 'baseline' in the same feature slot.
            models[task_id] = {
                'tasks': [
                    {'target_field': 'FLM%d' % free_mx_slots_ids[0],
                     'vivisection': False,
                     'task_id': task_id}
                    # here we got bad "grabli" - task_id and task_ids are ambigously common.
                    # TODO: fix this ambiguocity somehow
                ]
            }
        lm_in_mx_pattern = {'data': {'models': models}}

        final_mx_task_text = '\n---\n'.join([
            self.Parameters.mx_task_text,
            yaml.safe_dump(lm_in_mx_pattern),
        ])

        mx_task_file_name = 'mx_task_text.yml'
        with open(mx_task_file_name, 'w') as f:
            f.write(final_mx_task_text)

        oauth_token = sdk2.Vault.data(
            'ML-ENGINE',
            'robot-ml-engine_ml_engine_app_token'
        )

        logging.info(final_mx_task_text)
        run_mx_task(schedule_binary, mx_last_log_date, mx_task_file_name, max_attempt_time, oauth_token)
        kill_scheduler_from_task(self)
