# -*- coding: utf-8 -*-
import os
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from threading import Thread
from Queue import Queue
import numpy as np
from collections import defaultdict
from sklearn.metrics import roc_auc_score
from simplejson import JSONDecodeError
from yt.wrapper import ypath_split

from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.stability.mini_batch_monitoring.mini_batch_log_table import MiniBatchLogTable
from datacloud.stability.mini_batch_monitoring.mini_batch_to_solomon import post_results

# Module-level logger shared by every function in this file.
logger = get_basic_logger(__name__)

# Account id substituted into the scores API URL below.
INTERNAL_PARTNER = 'internal'
# Endpoint that ``do_request`` POSTs score requests to.
API_URL = 'https://datacloud.yandex.net/v1/accounts/{}/scores'.format(INTERNAL_PARTNER)
# Name of the environment variable consulted for the API token when the caller
# of ``perfrom_mini_batch_monitoring`` does not pass one explicitly.
INTERNAL_TOKEN_ENV = 'INTERNAL_SCORE_API_TOKEN'
# Number of worker threads started per monitoring run; the work queue is
# bounded at twice this value.
N_CONCURRENT_THREADS = 2
# Default cap on the number of distinct external ids read from a batch table.
MINI_BATCH_DEFAULT_SIZE = 10000


def requests_retry_session(retries=3, backoff_factor=0.3, status_forcelist=(500, 502, 504),
                           session=None):
    """Return a ``requests`` session with a retrying adapter mounted.

    :param retries: value used for the ``total``, ``read`` and ``connect``
        limits of the ``Retry`` policy.
    :param backoff_factor: ``backoff_factor`` of the ``Retry`` policy.
    :param status_forcelist: HTTP status codes passed to ``Retry`` as
        ``status_forcelist``.
    :param session: existing session to configure; a new ``requests.Session``
        is created when it is falsy.
    :return: the configured session (the same object as ``session`` when one
        was supplied).

    NOTE(review): urllib3's ``Retry`` by default restricts read/status retries
    to idempotent HTTP methods, so POST requests made through this session may
    only be retried on connection errors -- confirm that this is intended.
    """
    if not session:
        session = requests.Session()

    retry_policy = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    retrying_adapter = HTTPAdapter(max_retries=retry_policy)

    # The same adapter instance serves both schemes.
    for scheme in ('http://', 'https://'):
        session.mount(scheme, retrying_adapter)
    return session


def thread_work(queue, score_values, score_name, internal_token):
    """Build the worker loop executed by each scoring thread.

    The returned callable loops forever: it takes ``(external_id, id_vals)``
    tuples from ``queue``, requests the score via ``do_request`` and appends
    ``{'external_id': ..., 'score_value': ...}`` to ``score_values``.

    :param queue: queue of ``(external_id, id_vals)`` tuples to process.
    :param score_values: shared list that collects one result dict per item.
    :param score_name: passed through to ``do_request``.
    :param internal_token: passed through to ``do_request``.
    :return: zero-argument callable suitable as a ``Thread`` target.
    """
    def _inner():
        while True:
            external_id, id_vals = queue.get()
            try:
                score_value = do_request(score_name=score_name, internal_token=internal_token,
                                         id_vals=id_vals)
            except Exception:
                # An unexpected failure must not kill the worker: a dead thread
                # never calls task_done() and the producer's queue.join() would
                # block forever. Record the item as "no score" instead.
                logger.exception('Score request failed for external_id %r', external_id)
                score_value = None
            score_values.append({
                'external_id': external_id,
                'score_value': score_value
            })
            # Mark the item done only after its result is visible to the caller.
            queue.task_done()

    return _inner


def do_request(score_name, internal_token, id_vals):
    """Request a single score for one user from the scores API.

    :param score_name: name of the score to request.
    :param internal_token: value sent in the ``Authorization`` header.
    :param id_vals: iterable of id values; each one is sent as a separate
        ``{'id_value': ...}`` entry under ``user_ids.emails``.
    :return: ``score_value`` of the first score in the response when its
        ``has_score`` flag is set; ``None`` otherwise (request failure,
        non-200 status, malformed body or no score).
    """
    params = {
        'scores': [{'score_name': score_name}],
        # One dict per id value. A dict comprehension with the constant key
        # 'id_value' would collapse to a single entry holding only the last id.
        'user_ids': {'emails': [{'id_value': id_val} for id_val in id_vals]}
    }
    headers = {'Authorization': internal_token}

    try:
        # The session is created per call, so close it once the response
        # (which is fully read, as streaming is not requested) is available.
        with requests_retry_session() as session:
            resp = session.post(API_URL, headers=headers, json=params)
    except requests.exceptions.RequestException as exc:
        # Covers connection errors and timeouts as well as the other
        # request-level failures; treated as "no score".
        logger.warning('Score API request failed: %r', exc)
        return None

    if resp.status_code != 200:
        try:
            logger.error(resp.json())
        except ValueError:
            # JSON decode errors of every backend used by requests derive
            # from ValueError.
            logger.error('Bad response json, status: {}'.format(resp.status_code))
        return None

    try:
        result = resp.json()['scores'][0]
        if result['has_score']:
            return result['score_value']
    except (ValueError, KeyError, IndexError, TypeError):
        logger.error('Unexpected response body, status: {}'.format(resp.status_code))

    return None


def perfrom_mini_batch_monitoring(yt_client, partner_id, score_name, score_date, batch_table,
                                  max_batch_size=MINI_BATCH_DEFAULT_SIZE, internal_token=None):
    """Score a mini batch through the API and report its hit rate and ROC AUC.

    Rows of ``batch_table`` (fields ``external_id``, ``target``, ``id_value``)
    are read until ``max_batch_size`` distinct external ids have been seen.
    A score is then requested for every external id from worker threads and
    two metrics are computed:

    * ``hit`` -- share of external ids for which a score was obtained;
    * ``AUC`` -- ``roc_auc_score`` of the obtained scores against the targets.

    The metrics are recorded with ``MiniBatchLogTable().add_record`` and
    ``post_results``.

    :param yt_client: client whose ``read_table`` yields the batch rows.
    :param partner_id: stored with the results.
    :param score_name: name of the score to request; stored with the results.
    :param score_date: stored with the results.
    :param batch_table: path of the batch table; its last path component is
        stored as ``batch_name``.
    :param max_batch_size: maximum number of distinct external ids to read.
    :param internal_token: API token; when falsy, the environment variable
        named by ``INTERNAL_TOKEN_ENV`` is used instead.
    :return: tuple ``(hit, roc_auc)``.
    :raises AssertionError: if no token was passed and the environment
        variable is not set.
    :raises ValueError: if the batch table yields no rows.
    """
    internal_token = internal_token or os.getenv(INTERNAL_TOKEN_ENV, None)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type is kept for existing callers.
    if internal_token is None:
        raise AssertionError('Internal token was not found!')

    external_id2target = dict()
    external_id2id_vals = defaultdict(list)
    logger.info('reading batch')
    for row in yt_client.read_table(batch_table):
        external_id = row['external_id']
        external_id2target[external_id] = int(row['target'])
        external_id2id_vals[external_id].append(row['id_value'])

        if len(external_id2target) >= max_batch_size:
            break

    batch_size = len(external_id2target)
    logger.info('batch size is {}'.format(batch_size))
    if batch_size == 0:
        # Fail early with a clear message instead of a ZeroDivisionError in
        # the hit computation, and before any worker thread is started.
        raise ValueError('Batch table {} is empty'.format(batch_table))

    # Bounded queue: the producer below blocks while the workers are busy.
    queue = Queue(2 * N_CONCURRENT_THREADS)
    score_values = []
    for i in range(N_CONCURRENT_THREADS):
        logger.info('thread {}'.format(i))
        thread = Thread(target=thread_work(queue, score_values, score_name, internal_token))
        # Workers loop forever, so they must not keep the process alive.
        thread.daemon = True
        thread.start()

    logger.info('fire!')
    for external_id, id_vals in external_id2id_vals.iteritems():
        queue.put((external_id, id_vals))
    # Wait until every queued external id has been processed.
    queue.join()

    preds, ground = [], []
    for score_value in score_values:
        preds.append(score_value['score_value'])
        ground.append(external_id2target[score_value['external_id']])

    # Missing scores (None) become NaN in the float array.
    preds = np.array(preds, dtype=float)
    ground = np.array(ground, dtype=float)

    has_score = ~np.isnan(preds)
    not_nan_preds = preds[has_score]
    not_nan_ground = ground[has_score]

    hit = float(len(not_nan_preds)) / len(preds)
    roc_auc = roc_auc_score(not_nan_ground, not_nan_preds)

    results = dict(partner_id=partner_id, score_name=score_name, score_date=score_date,
                   batch_name=ypath_split(batch_table)[-1], AUC=roc_auc, hit=hit,
                   batch_size=batch_size)
    MiniBatchLogTable().add_record(**results)
    post_results(**results)

    return hit, roc_auc
