# -*- coding: utf-8 -*-

from collections import defaultdict
from datetime import datetime, timedelta
import textwrap
import requests

from sandbox import sdk2
from sandbox.common.errors import TaskFailure


def try_float(s):
    """Convert ``s`` to float if possible, otherwise return it unchanged.

    Used to normalize YQL result cells: numeric-looking cells become floats,
    while non-numeric ones (e.g. the ``role`` strings 'search'/'partner') and
    ``None`` (e.g. a NULL cell) are passed through as-is.
    """
    try:
        return float(s)
    except (TypeError, ValueError):
        # ValueError: non-numeric string; TypeError: None or another object
        # float() cannot handle. Either way keep the original value.
        return s


def precision(row):
    """Return filtered / (filtered + overfiltered) for a counters row.

    In the query, ``filtered`` counts events flagged by both logs and
    ``overfiltered`` counts events flagged only in bscount-event-log. Returns
    1.0 when the denominator is zero. Counts are coerced to float so integer
    counts are not truncated by Python 2 integer division.
    """
    f = float(row["filtered"])
    o = float(row["overfiltered"])
    return 1.0 if f + o == 0.0 else f / (f + o)


def recall(row):
    """Return filtered / (filtered + underfiltered) for a counters row.

    In the query, ``filtered`` counts events flagged by both logs and
    ``underfiltered`` counts events flagged only in bs-chevent-log. Returns
    1.0 when the denominator is zero. Counts are coerced to float so integer
    counts are not truncated by Python 2 integer division.
    """
    f = float(row["filtered"])
    u = float(row["underfiltered"])
    return 1.0 if f + u == 0.0 else f / (f + u)


def f_measure(row):
    """Return the F-measure (harmonic mean of precision and recall) of a row.

    Precision and recall are both 0 only when nothing was correctly filtered
    while there were over- and underfiltered events. That is the worst
    possible outcome, so the F-measure is 0.0 there (2*TP / (2*TP + FP + FN)
    with TP == 0). The previous code reported a perfect 1.0 in this case.
    """
    p = precision(row)
    r = recall(row)
    if p + r == 0.0:
        return 0.0
    return 2.0 * p * r / (p + r)


def build_metrics(row):
    """Build the dict of metrics reported to Solomon for one counters row.

    The four raw counters are copied from ``row``. Precision, recall and
    F-measure are derived from them.
    """
    counter_names = ("good_events", "underfiltered", "overfiltered", "filtered")
    result = dict((name, row[name]) for name in counter_names)
    result["precision"] = precision(row)
    result["recall"] = recall(row)
    result["f_measure"] = f_measure(row)
    return result


class BscountCluster(object):
    """Solomon pusher plus running counters for one bscount cluster.

    Counters passed to ``update_metrics`` are summed per (role, countertype)
    across all rows fed in. ``build_metrics`` converts those sums into metric
    dicts with the module-level ``build_metrics`` function.
    """

    SOLOMON_API_URL = "http://api.solomon.search.yandex.net"

    def __init__(self, cluster_name, solomon_api_token):
        from solomon import PushApiReporter, OAuthProvider

        self.cluster_name = cluster_name
        reporter_options = dict(
            project='bscount',
            cluster=cluster_name,
            service=cluster_name,
            url=self.SOLOMON_API_URL,
            auth_provider=OAuthProvider(solomon_api_token),
        )
        self.pusher = PushApiReporter(**reporter_options)

        def zero_counters():
            return {
                'good_events': 0,
                'underfiltered': 0,
                'overfiltered': 0,
                'filtered': 0,
            }

        # (role, countertype) -> counters; entries are created on first access.
        self.metrics = defaultdict(zero_counters)

    def update_metrics(self, row):
        """Add the counter columns of ``row`` to the totals of its (role, countertype)."""
        group = (row['role'], row['countertype'])
        totals = self.metrics[group]
        for name, current in list(totals.items()):
            totals[name] = current + row[name]

    def build_metrics(self):
        """Return ``{(role, countertype): metrics dict}`` built from the accumulated totals."""
        summary = {}
        for group, totals in self.metrics.items():
            # Resolves to the module-level build_metrics(row), not this method.
            summary[group] = build_metrics(totals)
        return summary


class BscountQualityComputation(sdk2.Task):
    """Measure bscount fraud-filtering quality for one hour and push it to Solomon.

    The YQL query joins the hourly ``bs-chevent-log`` (alias ``e``) and
    ``bscount-event-log`` (alias ``b``) tables on logid. Per
    (hostid, role, countertype) it counts events by whether each log has
    non-zero fraudbits. Those counters, plus precision/recall/F-measure, are
    pushed to Solomon per host and aggregated per bscount cluster.

    The steps run inside ``memoize_stage`` blocks so that re-entries of
    ``on_execute`` (after ``sdk2.WaitTime``) skip stages already completed.
    """

    SAMOGON_CLUSTER_API_URL = 'http://clusterapi-{name}.n.yandex-team.ru/hosts'
    # Seconds to wait for each Samogon cluster API request before failing.
    SAMOGON_REQUEST_TIMEOUT = 60
    # Samogon clusters whose hosts are reported; '_' is stripped when building
    # the clusterapi host name (see _get_cluster_mapping).
    CLUSTER_NAMES = ('bscount', 'bscount_1', 'bscount_2', 'bscount_3', 'bscount_4')

    # Filled in with str.format(process_hour=...), so any literal braces added
    # to the query must be doubled.
    REQUEST_TEMPLATE = '''
        select
            b.hostid as hostid,
            b.role as role,
            e.countertype as countertype,
            count_if(e.fraudbits == 0 and b.fraudbits == 0) as good_events,
            count_if(e.fraudbits != 0 and b.fraudbits == 0) as underfiltered,
            count_if(e.fraudbits == 0 and b.fraudbits != 0) as overfiltered,
            count_if(e.fraudbits != 0 and b.fraudbits != 0) as filtered
        from (
            select
                cast(logid as Uint64) as logid,
                cast(countertype as Uint64) as countertype,
                cast(fraudbits as Uint64) as fraudbits
            from
                `logs/bs-chevent-log/1h/{process_hour}`
        ) as e
        inner join (
            select
                HostID as hostid,
                if(IsSearchPage = 1, 'search', 'partner') as role,
                LogID as logid,
                FraudBits as fraudbits
            from
                `logs/bscount-event-log/1h/{process_hour}`
        ) as b
        on e.logid == b.logid
        group by
            b.hostid,
            b.role,
            e.countertype;
    '''

    class Requirements(sdk2.Requirements):
        cores = 1

        # Empty Caches declaration; presumably disables sandbox caches for
        # this task (NOTE(review): confirm against sdk2 docs).
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        # The token parameters are vault entry names, read with the task
        # author as owner (see sdk2.Vault.data calls below).
        yql_token = sdk2.parameters.String("YQL Token vault name", default="yql_token", required=True)
        solomon_api_token = sdk2.parameters.String("Solomon API token vault name", default="solomon_api_token", required=True)
        process_delay = sdk2.parameters.Integer("Process delay (in hours)", default=2, required=True)

    def _create_client(self):
        """Create a YQL client for the 'hahn' cluster using the vault-stored token."""
        from yql.api.v1.client import YqlClient

        token = sdk2.Vault.data(self.author, self.Parameters.yql_token)
        return YqlClient(db='hahn', token=token)

    def _select_hour(self):
        """Pick the hour to process and store it in ``self.Context.process_hour``.

        The hour is ``process_delay`` hours before the current UTC time,
        truncated to the start of the hour and formatted as
        ``%Y-%m-%dT%H:%M:%S``. That is the suffix used by the hourly log table
        paths in ``REQUEST_TEMPLATE``.
        """
        process_hour = datetime.utcnow() - timedelta(hours=self.Parameters.process_delay)
        process_hour = process_hour.replace(minute=0, second=0, microsecond=0)
        setattr(self.Context, 'process_hour', process_hour.strftime('%Y-%m-%dT%H:%M:%S'))

    def _form_query(self):
        """Return the YQL query text for the hour stored in the context."""
        process_hour = getattr(self.Context, 'process_hour')
        query = self.REQUEST_TEMPLATE.format(process_hour=process_hour)
        return textwrap.dedent(query)

    def _run_yql_task(self):
        """Start the YQL query and record its operation id in ``self.Context``."""
        def callback(request):
            # Persist the id before the query starts so later stages
            # (_wait_yql, _get_results) can look the operation up.
            setattr(self.Context, 'operation_id', request.operation_id)
        request = self.yql_client.query(self._form_query(), syntax_version=1)
        request.run(pre_start_callback=callback)

    def _wait_yql(self):
        """Check the stored YQL operation's status.

        :raises sdk2.WaitTime: while the operation is still in progress
            (presumably suspends the task and re-enters on_execute ~60 s later).
        :raises TaskFailure: if the operation finished with any status other
            than 'COMPLETED'.
        """
        from yql.client.operation import YqlOperationStatusRequest

        operation_id = getattr(self.Context, 'operation_id')
        status = YqlOperationStatusRequest(operation_id)
        status.run()
        if status.status in status.IN_PROGRESS_STATUSES:
            raise sdk2.WaitTime(60)
        if status.status != 'COMPLETED':
            raise TaskFailure("YQL query failed")

    def _get_results(self):
        """Fetch the first result table of the finished YQL operation.

        :return: list of dicts, one per row, keyed by column name. Every cell
            is passed through ``try_float``.
        :raises TaskFailure: if the operation returned no result tables.
        """
        from yql.client.operation import YqlOperationResultsRequest

        operation_id = getattr(self.Context, 'operation_id')
        results = YqlOperationResultsRequest(operation_id)
        results.run()

        table = next(iter(results.get_results()), None)
        if table is None:
            raise TaskFailure("YQL query returned no result tables")
        columns = [column_name for column_name, _column_type in table.columns]

        table.fetch_full_data()
        return [dict(zip(columns, map(try_float, row))) for row in table.rows]

    def _get_cluster_mapping(self, cluster_names):
        """Map every host of the given Samogon clusters to its cluster name.

        :param cluster_names: cluster names. Underscores are stripped to form
            the clusterapi host (``bscount_1`` -> ``clusterapi-bscount1``).
        :return: dict ``{host: cluster_name}``. A host listed by several
            clusters maps to the last of them.
        :raises requests.RequestException: on timeout, connection failure or
            an HTTP error status from the cluster API.
        """
        result = {}
        for name in cluster_names:
            url = self.SAMOGON_CLUSTER_API_URL.format(name=name.replace('_', ''))
            response = requests.get(url, timeout=self.SAMOGON_REQUEST_TIMEOUT)
            # Fail with a clear HTTP error instead of trying to parse an error
            # page as JSON.
            response.raise_for_status()
            for host in response.json()['value']:
                result[host] = name
        return result

    def _push_metrics(self, cluster, metrics, host, role, countertype, ts):
        """Push one metrics dict to ``cluster``'s Solomon pusher in a single call.

        Names and values are passed as parallel lists. Each metric gets its
        own labels dict, so a pusher that modifies labels in place cannot
        affect the other entries.
        """
        names = list(metrics.keys())
        values = [metrics[name] for name in names]
        labels = [{'host': host, 'role': role, 'partition': countertype} for _ in names]
        cluster.pusher.set_value(names, values, labels, ts_datetime=ts)

    def _push_results(self):
        """Push per-host and per-cluster quality metrics to Solomon.

        Each result row whose hostid belongs to one of ``CLUSTER_NAMES`` is
        pushed with its own host label and also added to its cluster's totals.
        Rows from unknown hosts are skipped. Afterwards each cluster pushes
        its totals per (role, countertype) with ``host='any'``. All points
        are stamped with the processed hour.
        """
        cluster_names = self.CLUSTER_NAMES
        cluster_mapping = self._get_cluster_mapping(cluster_names)
        solomon_api_token = sdk2.Vault.data(self.author, self.Parameters.solomon_api_token)
        clusters = {name: BscountCluster(name, solomon_api_token) for name in cluster_names}

        process_hour = datetime.strptime(getattr(self.Context, 'process_hour'), '%Y-%m-%dT%H:%M:%S')

        for row in self._get_results():
            cluster_name = cluster_mapping.get(row['hostid'])
            if cluster_name is None:
                continue
            cluster = clusters[cluster_name]
            cluster.update_metrics(row)
            self._push_metrics(
                cluster, build_metrics(row), row['hostid'], row['role'], row['countertype'], process_hour
            )

        for cluster in clusters.values():
            for (role, countertype), metrics in cluster.build_metrics().items():
                self._push_metrics(cluster, metrics, 'any', role, countertype, process_hour)

    def on_execute(self):
        self.yql_client = self._create_client()

        # Choose the hour once; later re-entries reuse the stored value.
        with self.memoize_stage.select_hour:
            self._select_hour()

        # commit_on_entrance=False: presumably the stage counts as done only
        # after it completes, so a failed launch is retried (confirm in sdk2 docs).
        with self.memoize_stage.run_yql_task(commit_on_entrance=False):
            self._run_yql_task()

        # Raises sdk2.WaitTime while the query is still running.
        self._wait_yql()

        with self.memoize_stage.push_results:
            self._push_results()
