# -*- coding: utf-8 -*-
from datetime import datetime
import json
from textwrap import dedent

import vh

from datacloud.dev_utils.yt.yt_utils import get_yt_client
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.ml_utils.vh_wrapper.graph_builder import GraphBuilder
from datacloud.ml_utils.benchmark.benchmark_results_table import BenchmarkResultsTable
from datacloud.ml_utils.benchmark.cubes import (
    process_bm_results, process_gs_results
)
from datacloud.ml_utils.benchmark.bm_solomon_helpers import upload_bm_to_solomon
from datacloud.dev_utils.time.patterns import FMT_DATE_HMST
from datacloud.ml_utils.vh_wrapper.helpers.cubes import run_grid_search


def _check_table_sorted(yt_client, table_path, sorted_by='external_id'):
    """Assert that a table is sorted and that ``sorted_by`` is its leading key.

    :param yt_client: object exposing ``get_attribute(path, name)``; the
        ``'sorted'`` and ``'sorted_by'`` attributes of the table are read.
    :param table_path: path of the table to check.
    :param sorted_by: name of the column that must come first in the table's
        ``sorted_by`` attribute.
    :return: ``True`` when both checks pass.
    :raises AssertionError: if the table is not sorted, or its first sort key
        differs from ``sorted_by``.
    """
    assert yt_client.get_attribute(table_path, 'sorted'), \
        'Table {} is not sorted!'.format(table_path)

    # Only the leading sort column matters; additional sort keys are allowed.
    leading_key = yt_client.get_attribute(table_path, 'sorted_by')[:1]
    assert leading_key == [sorted_by], \
        'Table {} should be sorted by {} first!'.format(table_path, sorted_by)

    return True


class BenchmarkGraphBuilder(GraphBuilder):
    """Builds and launches a benchmark graph made of grid searches.

    ``run`` validates the input tables, assembles one grid search per table
    inside a ``vh.Graph`` and starts it asynchronously. Once the graph has
    finished (see ``is_done``), ``bless`` records the downloaded results in
    ``BenchmarkResultsTable`` and passes them to ``upload_bm_to_solomon``.
    """
    DEFAULT_WORKFLOW_GUID = '12042491-fb32-4a9e-8ff0-d03e87ec9c82'

    def __init__(self, ticket_name, description, workflow_guid=DEFAULT_WORKFLOW_GUID,
                 logger=None, yt_client=None):
        """Store run metadata and prepare lazily-filled state.

        :param ticket_name: ticket identifier (string) attached to the run.
        :param description: free-form description (string) of the run.
        :param workflow_guid: forwarded to ``GraphBuilder.__init__``.
        :param logger: logger to use; defaults to ``get_basic_logger(__name__)``.
        :param yt_client: YT client to use; defaults to ``get_yt_client()``.
        :raises AssertionError: if ``ticket_name`` or ``description`` is not a
            string.
        """
        super(BenchmarkGraphBuilder, self).__init__(workflow_guid=workflow_guid)

        self.ticket_name = ticket_name
        assert isinstance(self.ticket_name, basestring)
        self.description = description
        assert isinstance(self.description, basestring)

        self.logger = logger or get_basic_logger(__name__)
        self.yt_client = yt_client or get_yt_client()
        self.creation_time = datetime.now().strftime(FMT_DATE_HMST)

        # Filled in by ``run``; declared here so the attribute always exists.
        self.results_vh_file = None
        # Lazily populated caches, see the ``wrokflow_info`` / ``results`` properties.
        self._wrokflow_info = None
        self._results = None
        # Set by ``run``; ``None`` means the graph has not been started yet.
        self._graph_keeper = None

    def _add_grid_searches(self, paths_to_yt_full_hit, paths_to_yt_on_cs):
        """Add one grid search plus result processing per input table.

        Full-hit tables are added first, then the clickstream ones.

        :param paths_to_yt_full_hit: mapping ``dataset_name -> table_path``.
        :param paths_to_yt_on_cs: mapping ``dataset_name -> table_path`` for
            tables processed with ``is_on_clickstream=True``.
        :return: list of ``process_gs_results`` outputs, one per table.
        """
        table_groups = (
            (paths_to_yt_full_hit, False),
            (paths_to_yt_on_cs, True),
        )

        bm_results_files = []
        for paths, is_on_clickstream in table_groups:
            for dataset_name, table_path in paths.items():
                gs_results = run_grid_search(
                    params=json.dumps({'table_path': table_path}),
                    ticket_name=self.ticket_name
                )
                bm_results_files.append(process_gs_results(
                    gs_results=gs_results,
                    dataset_name=dataset_name,
                    is_on_clickstream=is_on_clickstream
                ))

        return bm_results_files

    def _was_started(self):
        """Return ``True`` once ``run`` has stored a graph keeper."""
        return self._graph_keeper is not None

    def is_done(self):
        """Return whether the started graph has completed.

        :raises RuntimeError: if the graph has not been started.
        """
        if not self._was_started():
            raise RuntimeError('Run graph first!')

        return self.graph_keeper.get_total_completion_future().done()

    @property
    def graph_keeper(self):
        """Object returned by ``_run_async`` for the started graph.

        :raises RuntimeError: if the graph has not been started.
        """
        if not self._was_started():
            raise RuntimeError('Run graph first!')
        return self._graph_keeper

    @property
    def wrokflow_info(self):
        """Workflow info of the started graph, fetched once and cached.

        The misspelled name is kept because it is part of the public interface.

        :raises RuntimeError: if the graph has not been started.
        """
        if self._wrokflow_info is None:
            # The fetched value must be stored; otherwise this property would
            # always return None and ``bless`` could not build the graph URL.
            self._wrokflow_info = self.graph_keeper.get_workflow_info()

        return self._wrokflow_info

    @property
    def results(self):
        """Benchmark results downloaded from the graph, fetched once and cached.

        ``bless`` reads the keys ``'auc'``, ``'auc_on_cs'``, ``'std'`` and
        ``'std_on_cs'`` from this value.

        :raises RuntimeError: if the graph has not been started.
        """
        if self._results is None:
            self._results = self.graph_keeper.download(self.results_vh_file)

        return self._results

    def bless(self):
        """Record the benchmark results and log them.

        Adds a record to ``BenchmarkResultsTable`` and passes the same fields
        to ``upload_bm_to_solomon``.

        :raises RuntimeError: if the graph has not been started.
        """
        bmr_table = BenchmarkResultsTable(yt_client=self.yt_client)
        nirvana_graph = 'https://nirvana.yandex-team.ru/flow/{}/{}/graph'.format(
            self.wrokflow_info.workflow_id,
            self.wrokflow_info.workflow_instance_id
        )
        result = {
            'ts_hr_format': self.creation_time,
            'ticket': self.ticket_name,
            'nirvana_graph': nirvana_graph,
            'description': self.description,
            'auc': self.results['auc'],
            'auc_on_cs': self.results['auc_on_cs'],
            'std': self.results['std'],
            'std_on_cs': self.results['std_on_cs']
        }
        bmr_table.add_record(**result)
        upload_bm_to_solomon(**result)

        self.logger.info(self.results)

    def run(self, paths_to_yt_full_hit, paths_to_yt_on_cs):
        """Validate the input tables, build the graph and start it asynchronously.

        :param paths_to_yt_full_hit: mapping ``dataset_name -> table_path``.
        :param paths_to_yt_on_cs: mapping ``dataset_name -> table_path`` for
            clickstream tables.
        :raises AssertionError: if any input table fails ``_check_table_sorted``.
        """
        # Dedent the template *before* substituting the table lists: the joined
        # paths start at column 0, so dedenting afterwards finds no common
        # margin and leaves the whole message indented.
        message = dedent("""\
            Input tables are:
            Full hit:
            {full_hit_tables}

            On clickstream:
            {on_clickstream_tables}
        """).format(
            full_hit_tables='\n'.join(paths_to_yt_full_hit.values()),
            on_clickstream_tables='\n'.join(paths_to_yt_on_cs.values())
        )
        self.logger.info(message)

        for paths in (paths_to_yt_full_hit, paths_to_yt_on_cs):
            for path_to_yt in paths.values():
                _check_table_sorted(self.yt_client, path_to_yt)

        self.logger.info('Input tables checked!')

        with vh.Graph():
            bm_results_files = self._add_grid_searches(
                paths_to_yt_full_hit, paths_to_yt_on_cs)

            # NOTE(review): ``st_token`` and ``_run_async`` are not defined in
            # this class; presumably provided by ``GraphBuilder`` - confirm.
            self.results_vh_file = process_bm_results(
                bm_results_files, st_token=self.st_token,
                ticket=self.ticket_name, description=self.description)

            # Anything cached from a previous run no longer describes the
            # graph being started now.
            self._wrokflow_info = None
            self._results = None

            self._graph_keeper = self._run_async()
            self.logger.info('Building your graph...')
