# -*- coding: utf-8 -*-
import numpy as np
from textwrap import dedent

from startrek_client import Startrek
from datacloud.input_pipeline.input_pipeline.star_tracker.st_logger_interface import ST_logger_interface


class ST_Logger(ST_logger_interface):
    """Collect Startrek wiki-markup fragments and publish them as one issue comment.

    The ``write_*`` methods append text fragments to ``self.message``; ``push()``
    joins them with newlines and either creates a new comment (first push) or
    edits the comment it created / was given (subsequent pushes).
    """

    USER_AGENT = 'robot-xprod'
    BASE_URL = 'https://st-api.yandex-team.ru/v2'
    IS_CREDIT_SCORING_TAG = 'credit_scoring'
    NOT_CREDIT_SCORING_TAG = 'not_credit_scoring'
    RETRO_TEST_TAG = 'retro_test'

    def __init__(self, token, ticket_name, useragent=None, base_url=None, comment_id=None):
        """Create a Startrek client and bind to issue ``ticket_name``.

        :param token: API token passed to the Startrek client.
        :param ticket_name: key of the issue whose comments/tags are written.
        :param useragent: overrides ``USER_AGENT`` when truthy.
        :param base_url: overrides ``BASE_URL`` when truthy.
        :param comment_id: id of an existing comment to continue; its current text
            seeds the message buffer and ``push()`` will edit it in place.
        :raises ValueError: if ``comment_id`` is given but no such comment exists.
        """
        self.st_client = Startrek(
            useragent=useragent or ST_Logger.USER_AGENT,
            base_url=base_url or ST_Logger.BASE_URL,
            token=token
        )
        self.st_issue = self.st_client.issues[ticket_name]
        self.message = []
        # Comment object being edited; set by find_comment() or by the first push().
        self.comment = None
        self._comment_id = comment_id

        if self._comment_id:
            self.find_comment()

    def find_comment(self):
        """Locate the comment whose id equals ``self._comment_id`` on the issue.

        On success stores it in ``self.comment`` and replaces the message buffer
        with the comment's current text, so further writes append to it.

        :raises ValueError: if no comment id is set, or no comment matches it.
        """
        if not self._comment_id:
            raise ValueError('No comment id provided!')
        for comment in self.st_issue.comments.get_all():
            # NOTE(review): plain == comparison, so a str id will not match a
            # numeric id from the API — confirm the type callers pass in.
            if comment._value['id'] == self._comment_id:
                self.message = [comment._value['text']]
                self.comment = comment
                return
        raise ValueError('Bad comment id!')

    @property
    def comment_id(self):
        """Id of the comment being edited, or None before the first push."""
        return self._comment_id

    @staticmethod
    def rows_gen(d, sort_by_num=False):
        """Yield one wiki-table row ``|| key| abs| rel%||`` per key of ``d``.

        :param d: maps a parameter value to a dict with an ``'abs'`` entry and a
            ``'rel'`` ratio (multiplied by 100 and shown as a percentage).
        :param sort_by_num: False -> rows ordered by key ascending;
            True -> rows ordered by ``(abs, rel)`` descending.
        """
        if sort_by_num:
            # Sort on explicit numbers. The previous code compared the value
            # dicts themselves, which only "worked" through Python 2's arbitrary
            # dict ordering and raises TypeError on Python 3; (abs, rel) gives the
            # same order Python 2 produced for {'abs', 'rel'} dicts.
            sorted_param = sorted(
                d,
                key=lambda key: (d[key]['abs'], d[key]['rel']),
                reverse=True
            )
        else:
            sorted_param = sorted(d)
        for param in sorted_param:
            val = d[param]
            yield '|| {0}| {1}| {2:.2f}%||'.format(param, val['abs'], float(val['rel']) * 100)

    @staticmethod
    def make_table(parameter_name, d, sort_by_num=False):
        """Render ``d`` as a Startrek ``#| ... |#`` table headed by ``parameter_name``.

        Rows come from ``rows_gen(d, sort_by_num)``.
        """
        rows = '\n'.join(ST_Logger.rows_gen(d, sort_by_num=sort_by_num))
        return dedent("""\
            #|
            || {0}| abs| rel||
            {1}
            |#\
        """).format(parameter_name, rows)

    def write_initial_comment(self, ticket_name, input_file):
        """Append the "Uploaded Sample" block naming ``ticket_name`` and ``input_file``."""
        text = dedent("""\
            **Uploaded Sample**
            %%
            {{
            'field_name': 'pipeline_{ticket_name}',
            'field_value': {input_file}
            }}
            %%"""
        ).format(  # noqa
            ticket_name=ticket_name,
            input_file=input_file
        )

        self.message.append(text)

    def write_link(self, data_dir):
        """Append a link to ``data_dir`` in the hahn YT navigation UI."""
        text = dedent("""
            ((https://yt.yandex-team.ru/hahn/navigation?path={0} {0}))"""
        ).format(data_dir)  # noqa

        self.message.append(text)

    def write_min_max_retro(self, min_retro_date, max_retro_date):
        """Append a preformatted block with the min/max retro dates."""
        text = dedent("""
            %%
            min retro date: {min_retro_date}
            max retro date: {max_retro_date}
            %%"""
        ).format(  # noqa
            min_retro_date=min_retro_date,
            max_retro_date=max_retro_date
        )

        self.message.append(text)

    @staticmethod
    def _get_months_column(months):
        """Return ``'<month>  <value>'`` lines for ``months``, sorted by month key."""
        sorted_months = sorted(months.keys())
        months_list = ['{}  {}'.format(month, months[month]) for month in sorted_months]

        return '\n'.join(months_list)

    def write_contracts_by_month(self, months_dict):
        """Append a collapsible "Contracts by month" section built from ``months_dict``."""
        months_column = ST_Logger._get_months_column(months_dict)
        text = dedent("""
            <{{Contracts by month
            {months_column}
            }}>
        """
        ).format(months_column=months_column)  # noqa

        self.message.append(text)

    def write_tables(self, tables_header, tables_dict, sort_by_num=False):
        """Append a bold header followed by one table per entry of ``tables_dict``.

        :param tables_dict: maps table name to the dict accepted by ``make_table``.
            Tables are emitted in the dict's iteration order.
        """
        tables = '\n'.join(
            ST_Logger.make_table(
                table_name,
                tables_dict[table_name],
                sort_by_num=sort_by_num
            ) for table_name in tables_dict
        )
        text = dedent("""
            **{tables_header}**
            {tables}"""
        ).format(  # noqa
            tables_header=tables_header,
            tables=tables
        )

        self.message.append(text)

    def push(self):
        """Publish the whole message buffer and return the comment id.

        The first call creates a comment on the issue and remembers its id;
        later calls edit that comment with the full (re-joined) buffer.
        """
        comment_text = '\n'.join(self.message)
        if not self._comment_id:
            self.comment = self.st_issue.comments.create(text=comment_text)
            self._comment_id = self.comment._value['id']
        else:
            self.comment.update(text=comment_text)

        return self._comment_id

    def write_train_results(self, results, features_tag=''):
        """Append per-target cross-validation results.

        :param results: maps target name to a dict with ``'best_score'``,
            ``'best_params'`` (containing ``'C'``) and ``'fold_results'`` holding
            equal-length ``'train_auc'`` and ``'val_auc'`` sequences.
        :param features_tag: optional line printed under the "Train" header.
        """
        parts = [dedent("""\
            **Train**
            {features_tag}\
        """).format(features_tag=features_tag)]

        for target in sorted(results.keys()):
            result = results[target]

            fold_results = result['fold_results']
            train_aucs = fold_results['train_auc']
            val_aucs = fold_results['val_auc']
            mean_train, std_train = np.mean(train_aucs), np.std(train_aucs)
            mean_val, std_val = np.mean(val_aucs), np.std(val_aucs)

            # One "train  val" line per fold.
            aucs = '\n'.join(
                '{0:.4f}  {1:.4f}'.format(train, val)
                for train, val in zip(train_aucs, val_aucs)
            )

            parts.append(dedent("""

                Target: {2}
                Best score {0:.4f}
                Best C param {1:.2f}\
            """).format(result['best_score'], result['best_params']['C'], target))

            parts.append(dedent("""
                %%
                Train   Val
                {0}
                %%

                Train AUC: {1:.4f} {2:.4f}
                Val AUC: {3:.4f} {4:.4f}
            """).format(aucs, mean_train, std_train, mean_val, std_val))

        self.message.append(''.join(parts))

    def drop_message(self):
        """Clear the message buffer (does not touch any published comment)."""
        self.message = []

    def update_tags(self, partner_id, is_credit_scoring, other_tags=None):
        """Replace the credit-scoring tag and add partner/retro tags on the issue.

        Existing tags are kept except the two credit-scoring tags, one of which is
        re-added according to ``is_credit_scoring``. ``partner_id``,
        ``RETRO_TEST_TAG`` and ``other_tags`` are then added; duplicates are dropped
        while preserving first-occurrence order.
        """
        scoring_tags = {self.IS_CREDIT_SCORING_TAG, self.NOT_CREDIT_SCORING_TAG}
        # Filter against a set of whole tags. The old code called
        # set.difference(tag_a, tag_b) with two strings, which treats each string
        # as an iterable of characters and never removed the scoring tags.
        kept = [tag for tag in (self.st_issue.tags or []) if tag not in scoring_tags]

        if is_credit_scoring:
            scoring_tag = self.IS_CREDIT_SCORING_TAG
        else:
            scoring_tag = self.NOT_CREDIT_SCORING_TAG
        added = [scoring_tag, partner_id, self.RETRO_TEST_TAG] + list(other_tags or [])

        tags = []
        seen = set()
        for tag in kept + added:
            if tag not in seen:
                seen.add(tag)
                tags.append(tag)

        self.st_issue.update(tags=tags)
