# -*- coding: utf-8 -*-

from flask import redirect, url_for
from yt.yson import YsonEntity
from collections import defaultdict
from mongo import mongo_db
from infra.yp.account_estimation import Account

import base64
import datetime
import functools
import logging
import numpy as np
import pandas as pd
import pymongo
import random
import requests
import retrying
import time

import abc_api
import startrek_api

from constants import (
    ABC_REMAPPING,
    DISPENSER_REMAPPING,
    ABC_API_URL,
    LOG_FORMATTER,
    GROUP_FULL_CARD_TEMPLATE,
    GROUP_SHORT_CARD_TEMPLATE,
    VALID_MASTER_GROUPS,
    DISPENSER_URL,
    GENCFG_GROUPS
)


# ABC API endpoint for resource (quota) requests; appended to ABC_API_URL.
QUOTA_PATH = "/resources/request/"

# Quota-request attribute names whose values are numeric and are parsed with
# get_quota_value(); all other attributes are kept as raw values.
QUOTA_NUMBERS_KEYS = [
    "cpu",
    "memory",
    "hdd",
    "ssd",
    "ip4",
    "io_hdd",
    "io_ssd",
    "net_bandwidth",
]

# Identical quota requests for the same (location, segment) whose modification
# times are closer than this are considered duplicates by
# reject_duplicates_for_service(). 100 seconds == 1 minute 40 seconds.
REJECT_TIME_DELTA = datetime.timedelta(seconds=100)

logger = logging.getLogger(__name__)


def retry_until_success(f):
    """Decorate ``f`` so that it is called again and again until it does not raise.

    After every failed attempt the exception is logged and the wrapper sleeps
    for a random 5-15 seconds before retrying. There is no attempt limit.

    :param f: the callable to wrap.
    :returns: a wrapper with ``f``'s metadata. It returns whatever ``f``
        returned on its first successful call.
    """
    @functools.wraps(f)
    def retry(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Exception:
                # ``except Exception`` (not a bare ``except``) so KeyboardInterrupt
                # and SystemExit still stop the otherwise endless loop.
                logger.exception("Call to %s failed, retrying", getattr(f, "__name__", f))
            # Randomized back-off spreads retries of concurrent callers.
            time.sleep(5 + random.random() * 10)
    return retry


@retry_until_success
def create_yp_quota_request_for_abc_service_until_success(session, service_id, quota):
    """Submit a YP quota request for an ABC service, retrying without limit until it succeeds.

    See create_yp_quota_request_for_abc_service for the meaning of the arguments.
    """
    return create_yp_quota_request_for_abc_service(session, service_id, quota)


@retrying.retry(stop_max_attempt_number=2)
def create_yp_quota_request_for_abc_service_with_retries(session, service_id, quota):
    """Submit a YP quota request for an ABC service, making at most two attempts.

    See create_yp_quota_request_for_abc_service for the meaning of the arguments.
    """
    return create_yp_quota_request_for_abc_service(session, service_id, quota)


def create_yp_quota_request_for_abc_service(session, service_id, quota):
    """POST a quota request for an ABC service to the ABC resources API.

    The payload uses ``resource_type`` 96 (presumably ABC's YP quota resource
    type; confirm against ABC). Its ``data`` is ``{"comment": "yp quota"}``
    extended with (and overridable by) the entries of ``quota``.

    :param session: an HTTP session exposing ``post(url, json=...)``, such as a
        ``requests.Session``.
    :param service_id: id of the ABC service the quota is requested for.
    :param quota: mapping of quota fields merged into the payload's ``data``.
    :raises: whatever ``response.raise_for_status()`` raises on an error
        status (``requests.HTTPError`` for requests sessions).
    """
    quota_data = {"comment": "yp quota"}
    quota_data.update(quota)
    payload = {
        'resource_type': 96,
        'service': service_id,
        'data': quota_data,
    }
    response = session.post(ABC_API_URL + QUOTA_PATH, json=payload)
    response.raise_for_status()


def send_commentary_to_ticket(ticket_key, quota, groups):
    """Comment on a Startrek ticket with the granted quota, then close it as "fixed".

    :param ticket_key: key of the Startrek ticket.
    :param quota: iterable of per-location quota dicts with ``location``,
        ``cpu``, ``memory``, ``hdd`` and ``ssd`` keys.
    :param groups: names of the groups the quota was granted for; joined with commas.
    """
    rows = [
        [
            entry['location'],
            entry['cpu'],
            str(entry['memory']) + "Gb",
            str(entry['hdd']) + "Tb",
            str(entry['ssd']) + "Tb",
        ]
        for entry in quota
    ]
    table = pd.DataFrame(rows, columns=["DC", "Cpu", "Memory", "Hdd", "Ssd"])
    # pretty_print_pandas_table returns None for an empty frame, so only call it with rows.
    if rows:
        table = pretty_print_pandas_table(table, args=None)

    message = "\n".join([
        "Данная квота: <# {} #>".format(table.to_html()),
        "была выдана в счет групп {}".format(','.join(groups)),
    ])
    startrek_api.add_comment_to_the_ticket(ticket_key, message)
    startrek_api.close_ticket(ticket_key, "fixed")


def create_not_existing_record(collection, name, default_value):
    """Insert ``{name: default_value}`` unless some document in ``collection`` already has field ``name``.

    ``find_one`` replaces ``find(...).count()``. Cursor.count() is deprecated
    (PyMongo 3.7) and removed (PyMongo 4), and a single-document probe is all
    that is needed here.

    Note: the existence check and the insert are separate operations, so two
    concurrent callers may both insert.
    """
    if collection.find_one({name: {'$exists': True}}) is None:
        collection.insert({name: default_value})


def increase_unistat_signal_value(name, value):
    """Add ``value`` to the unistat signal ``name`` and return the new value.

    The signal lives in the ``unistat_signals`` collection as
    ``{name: {"cur_value": <float>}}`` and is created with 0.0 when missing.

    NOTE(review): this is a read-modify-write over two separate operations,
    not an atomic ``$inc``, so concurrent increments of the same signal can be lost.
    """
    signals = mongo_db["unistat_signals"]
    create_not_existing_record(signals, name, {"cur_value": 0.0})
    document = signals.find_one({name: {'$exists': True}})
    new_value = document[name]["cur_value"] + value
    signals.update({"_id": document["_id"]}, {"$set": {name: {"cur_value": new_value}}})
    return new_value


def pretty_print_pandas_table(table, args, total_stats_columns=None):
    """Prepare a DataFrame for display.

    Optionally sorts and truncates it according to ``args``. Optionally appends
    a ``'total'`` row that sums ``total_stats_columns`` (over all rows, before
    truncation), cast to int64. Finally, cells holding a ``YsonEntity`` are
    replaced with ''.

    :param table: the DataFrame to format. It is NOT modified; a private copy is used.
    :param args: None to skip sorting/truncation, or an object with
        ``sort_by`` (comma-separated column names, or None), ``sort_order``
        ('asc' sorts ascending, anything else descending) and
        ``output_limit`` (keep that many rows when > 0).
    :param total_stats_columns: optional list of columns to total.
    :returns: the formatted DataFrame, or None when ``table`` has no rows.
    """
    if table.shape[0] == 0:
        logger.warning('Nothing found with these parameters: %s', args)
        return None

    # Work on a real copy. The caller's frame is neither sorted in place nor
    # given a 'total' row, and pandas will not emit SettingWithCopyWarning for
    # the assignments below (what the former ``table.is_copy = False`` hack
    # tried to suppress).
    table = table.copy()

    if args is not None and args.sort_by is not None:
        table.sort_values(args.sort_by.split(','), ascending=args.sort_order == 'asc', inplace=True)

    total = None
    if total_stats_columns is not None:
        total = table[total_stats_columns].sum()

    if args is not None and args.output_limit > 0:
        table = table.head(args.output_limit).copy()

    if total is not None:
        table.loc['total'] = total
        table[total_stats_columns] = table[total_stats_columns].applymap(np.int64)

    return table.applymap(lambda x: '' if isinstance(x, YsonEntity) else x)


def logging_wrapper(time_flag=True, dump_flag=False):
    """Build a decorator that logs a "Start" line before and an "End" line after each call.

    Each line contains the function name. It optionally contains the current
    timestamp (taken separately for Start and End) and optionally a dump of
    the call's arguments.

    :param time_flag: include the timestamp in each line.
    :param dump_flag: include ``name=repr(value)`` for the named positional
        arguments, plus the extra positional ``args`` and the ``kwargs``.
    """
    def wrap(func):
        """Wrap ``func`` so every call is logged on entry and on normal return."""
        # __code__ / __name__ work on Python 2.6+ and Python 3 (func_code /
        # func_name were Python 2 only).
        code = func.__code__
        argnames = code.co_varnames[:code.co_argcount]
        fname = func.__name__

        def timestamp():
            return "Time: '{}'".format(datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f"))

        @functools.wraps(func)
        def echo_func(*args, **kwargs):
            data = ["Start", "Func name: '{}'".format(fname)]
            if time_flag:
                data.append(timestamp())
            if dump_flag:
                # Only build the (possibly large) repr dump when it is logged.
                entries = list(zip(argnames, args[:len(argnames)]))
                entries.append(("args", list(args[len(argnames):])))
                entries.append(("kwargs", kwargs))
                data.append("Function dump: (" + ', '.join('%s=%r' % entry for entry in entries) + ")")
            logger.info('. '.join(data))
            res = func(*args, **kwargs)
            data[0] = "End"
            if time_flag:
                # The timestamp sits at index 2; refresh it so "End" shows the end time.
                data[2] = timestamp()
            logger.info('. '.join(data))
            return res

        return echo_func
    return wrap


def get_ticket_for_banned_service(service_id):
    """Return ``data.ticket`` of the ``banned_abc_services`` record for ``service_id``.

    The record is looked up by ``_id == str(service_id)``. If there is no such
    record, ``find_one`` returns None and indexing it raises TypeError, so check
    is_abc_service_banned() first.
    """
    record = mongo_db["banned_abc_services"].find_one({"_id": str(service_id)})
    return record["data"]["ticket"]


def is_abc_service_banned(service_id):
    """Return True when ``banned_abc_services`` has a record with ``_id == str(service_id)``.

    Uses ``find_one`` instead of the deprecated (PyMongo 3.7) and removed
    (PyMongo 4) ``Cursor.count()``.
    """
    banned_abc_services = mongo_db["banned_abc_services"]
    return banned_abc_services.find_one({"_id": str(service_id)}) is not None


def handle_all_tickets_from_queue_with_filter(session, filter_json, callback, *args):
    """Call ``callback(session, ticket, *args)`` for every Startrek issue matching ``filter_json``."""
    for ticket in startrek_api.get_all_issues_with_filter(filter_json):
        callback(session, ticket, *args)


def get_all_dumped_groups():
    """Return the union of the ``groups`` lists of every record in the GENCFG_GROUPS dump collection."""
    collected = set()
    for document in get_collection_by_type(GENCFG_GROUPS).find():
        collected.update(document["groups"])
    return collected


def encode_key_for_mongo(key):
    """Escape dots in ``key`` so it can be used as a MongoDB field name.

    MongoDB keys can't contain dots. Every "." is therefore replaced with the
    six-character sequence: a literal backslash followed by ``u002e``.
    decode_key_for_mongo() reverses this.

    The replacement is spelled with an escaped backslash on purpose. On
    Python 2 the non-raw byte-string literal used previously meant exactly
    those six characters, but on Python 3 it is just ".", which would make this
    function a silent no-op. Keeping the six-character form presumably matches
    keys already stored by the Python 2 code.
    """
    return key.replace(".", "\\u002e")


def decode_key_for_mongo(key):
    """Undo encode_key_for_mongo(): turn each backslash + ``u002e`` sequence back into ".".

    The escape is spelled with an escaped backslash so that it means the same
    six characters on Python 2 and Python 3. On Python 3 the previous non-raw
    literal was just ".", which made decoding a no-op.
    """
    return key.replace("\\u002e", ".")


@retrying.retry(stop_max_attempt_number=1)
def check_guest_group(group):
    """Return the ``created_from_portovm_group`` property from the full card of ``group``.

    The card is fetched from ``GROUP_FULL_CARD_TEMPLATE.format(group)``. The
    result is None when the card has no ``properties`` or lacks that property.
    The retry decorator allows a single attempt, so the call is not actually
    retried.

    :raises requests.HTTPError: when the card request returns an error status.
    """
    # The session is closed (its connection pool released) when the block exits.
    with requests.Session() as session:
        # NOTE(review): verify=False disables TLS certificate verification;
        # kept for compatibility, but consider pointing ``verify`` at a CA bundle.
        response = session.get(GROUP_FULL_CARD_TEMPLATE.format(group), verify=False)
        response.raise_for_status()
        return response.json().get('properties', {}).get('created_from_portovm_group', None)


def convert_string_to_date(time_string, format_string):
    """Parse ``time_string`` with the strptime-style ``format_string`` into a ``datetime.datetime``."""
    parsed = datetime.datetime.strptime(time_string, format_string)
    return parsed


def convert_date_to_string(time, format_string):
    """Format the ``datetime.datetime`` ``time`` with the strftime-style ``format_string``.

    The unbound ``datetime.datetime.strftime`` is used, so ``time`` must be a
    ``datetime.datetime`` instance.
    """
    formatter = datetime.datetime.strftime
    return formatter(time, format_string)


def get_master_group(group):
    """Return the ``master`` field from the short card of ``group``.

    The card is fetched from ``GROUP_SHORT_CARD_TEMPLATE.format(group)``. The
    result is None when the request does not succeed or the card has no ``master``.
    """
    # The session is closed (its connection pool released) when the block exits.
    with requests.Session() as session:
        # NOTE(review): verify=False disables TLS certificate verification.
        response = session.get(GROUP_SHORT_CARD_TEMPLATE.format(group), verify=False)
        if response.ok:
            return response.json().get('master', None)
    return None


def is_valid_slave(group):
    """Return True when the master of ``group`` (from its short card) is one of VALID_MASTER_GROUPS."""
    return get_master_group(group) in VALID_MASTER_GROUPS


def add_record_fields(yandex_login, status, request_type):
    """Build the bookkeeping fields stored with a new request record.

    ``state`` and ``showState`` both start as ``status``. ``createdTime`` and
    ``lastUpdated`` share one timestamp; two separate ``now()`` calls could
    make a fresh record look already updated.
    """
    now = datetime.datetime.now()
    return {
        'yandex_login': yandex_login,
        'lastUpdated': now,
        'createdTime': now,
        'state': status,
        'showState': status,
        'type': request_type,
    }


def get_user_history(yandex_login, request_type):
    """Return the user's requests of ``request_type`` as ``[number, record]`` pairs.

    Records come first from ``requests_queue`` and then from
    ``completed_requests``. Each collection is ordered by ``lastUpdated``
    descending. Numbers count down from ``len(records) - 1`` to 0 in that order.
    """
    query = {'type': request_type, 'yandex_login': yandex_login}
    records = []
    for collection_name in ("requests_queue", "completed_requests"):
        records.extend(mongo_db[collection_name].find(query, sort=[("lastUpdated", -1)]))

    total = len(records)
    return [[total - position - 1, record] for position, record in enumerate(records)]


def create_new_request_or_display_user_history(request,
                                               request_type,
                                               redirect_url,
                                               form,
                                               yandex_login,
                                               status,
                                               get_record_from_form_callback):
    """Handle a request-form page: store a valid submission or return the user's history.

    On a valid submission, the record returned by
    ``get_record_from_form_callback(form)`` is extended with
    add_record_fields() and inserted into ``requests_queue``.

    :returns: a ``(response, history)`` pair.
        - After a valid submission: ``response`` redirects to
          ``url_for(redirect_url)`` with the new record's id as the
          ``record_id`` query argument, and ``history`` is ``[]``.
        - Otherwise (GET, or a POST that failed validation, which is logged):
          ``response`` is None, and ``history`` is get_user_history() for this
          user and request type.
    """
    requests_collection = mongo_db["requests_queue"]

    if form.validate_on_submit():
        record = get_record_from_form_callback(form)
        record.update(add_record_fields(yandex_login, status, request_type))
        # insert() fills record['_id'], which the redirect URL relies on.
        requests_collection.insert(record)
        return redirect(url_for(redirect_url) + '?record_id={}'.format(record['_id'])), []
    elif request.method == 'POST':
        logger.warning("Yandex login: {}. Expect GET request or POST with valid form".format(yandex_login))

    return None, get_user_history(yandex_login, request_type)


def get_quota_table_by_accounts(account_per_dc, full_cores=False, distributors=None):
    """Build a per-DC quota table for display.

    A DC is skipped when none of its cpu, memory, hdd and ssd values is > 0.
    Missing ``io_ssd`` / ``io_hdd`` / ``net_bandwidth`` values count as 0. The
    input dicts are read only and not modified.

    :param account_per_dc: mapping of DC name -> dict of resource amounts;
        ``cpu``, ``memory``, ``hdd`` and ``ssd`` are required.
    :param full_cores: when false, cpu is divided by 1000 and the throughput
        values are divided by 1000 and rounded; when true both are shown as-is.
    :param distributors: optional mapping DC name -> distributor. When
        non-empty, a ``Distributor`` column is added, defaulting to "abc".
    :returns: a DataFrame, passed through pretty_print_pandas_table() when it
        has at least one row.
    """
    def throughput(value):
        # Raw value when full_cores, otherwise divided by 1000 and rounded.
        return str(value if full_cores else round(value / 1000.0)) + "MB/s"

    rows = []
    # items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for dc, resources in account_per_dc.items():
        if not any(resources[key] > 0 for key in ("cpu", "memory", "hdd", "ssd")):
            continue
        row = [
            dc.upper(),
            resources["cpu"] if full_cores else resources["cpu"] / 1000.0,
            str(resources["memory"]) + "Gb",
            str(resources["hdd"]) + "Tb",
            str(resources["ssd"]) + "Tb",
            # .get() instead of inserting defaults into the caller's dict.
            throughput(resources.get("io_ssd", 0)),
            throughput(resources.get("io_hdd", 0)),
            throughput(resources.get("net_bandwidth", 0)),
        ]
        if distributors:
            row.append(distributors.get(dc, "abc"))
        rows.append(row)

    columns = ['DC', 'Cpu', 'Memory', 'Hdd', 'Ssd', 'IO_SSD', 'IO_HDD', 'NET_BANDWIDTH']
    if distributors:
        columns.append('Distributor')
    table = pd.DataFrame(rows, columns=columns)
    # pretty_print_pandas_table returns None for an empty frame, so only call it with rows.
    if rows:
        table = pretty_print_pandas_table(table, args=None)
    return table


def get_balancer_quota_table(balancer_quota):
    """Build a display table of balancer quota with one row per balancer mode.

    ``balancer_quota.int_all_values()`` is called first. Each item of
    ``balancer_quota.to_json()`` is then expected to have ``balancer_mode`` and
    lowercase DC keys (``man``, ``sas``, ``vla``, ``iva``, ``myt``).

    :returns: a DataFrame, passed through pretty_print_pandas_table() when it
        has at least one row.
    """
    columns = ['BALANCER_MODE', 'MAN', 'SAS', 'VLA', 'IVA', 'MYT']
    dc_keys = [column.lower() for column in columns[1:]]

    balancer_quota.int_all_values()
    rows = [
        [entry['balancer_mode']] + [entry[key] for key in dc_keys]
        for entry in balancer_quota.to_json()
    ]

    table = pd.DataFrame(rows, columns=columns)
    if rows:
        table = pretty_print_pandas_table(table, args=None)
    return table


def get_service_dump_by_name(service_info, collection_name):
    """Return the dumped record for a service from the ``collection_name`` dump collection.

    For ABC_REMAPPING and DISPENSER_REMAPPING, ``service_info`` is reduced to
    the last non-empty path component (e.g. a slug from a URL). The record is
    looked up by ``_id`` equal to that name, then by ``_id`` equal to the
    base64 of its UTF-8 encoding.

    :returns: the matching record as a dict, or ``{}`` (after logging an error)
        when neither key matches.
    """
    if collection_name in [ABC_REMAPPING, DISPENSER_REMAPPING]:
        service_info = service_info.strip('/').split('/')[-1]

    service_remapping_collection = get_collection_by_type(collection_name)

    # Query by _id directly instead of loading the whole collection per lookup.
    record = service_remapping_collection.find_one({"_id": service_info})
    if record is not None:
        return record

    # b64encode returns bytes on Python 3; decode so it can match string _ids.
    base64_service_info = base64.b64encode(service_info.encode('utf-8')).decode('ascii')
    record = service_remapping_collection.find_one({"_id": base64_service_info})
    if record is not None:
        return record

    logger.error("Not valid {} service information: {}".format(collection_name, service_info))
    return {}


def convert_account_to_dispenser_resources(account):
    """Convert an account's cpu/memory/ssd/hdd amounts into a list of Dispenser resource entries.

    Each entry is ``{"amount": {"value": ..., "unit": ...}, "resource_key": ...}``.
    cpu uses unit PERMILLE_CORES; memory, ssd and hdd use BYTE. Values are
    passed through unchanged, so they must already be in those units.
    Entries come in a fixed order: cpu, memory, ssd, hdd.

    :param account: a mapping indexable by "cpu", "memory", "ssd" and "hdd".
    """
    # A tuple of pairs gives a deterministic order and avoids Python-2-only iteritems().
    resource_type_to_unit_mapping = (
        ("cpu", "PERMILLE_CORES"),
        ("memory", "BYTE"),
        ("ssd", "BYTE"),
        ("hdd", "BYTE"),
    )
    return [
        {
            "amount": {
                "value": account[resource_type],
                "unit": unit,
            },
            "resource_key": resource_type,
        }
        for resource_type, unit in resource_type_to_unit_mapping
    ]


def convert_resources_to_minimal_units(resources):
    """Scale memory by 1024**3 and ssd/hdd by 1024**4, leaving cpu unchanged.

    This presumably converts GiB (memory) and TiB (disks) into bytes, matching
    the Gb/Tb display units used elsewhere in this module. Only the cpu,
    memory, ssd and hdd keys are returned.
    """
    bytes_per_gib = 1024.0 ** 3
    bytes_per_tib = 1024.0 ** 4
    converted = {'cpu': resources['cpu']}
    converted['memory'] = resources['memory'] * bytes_per_gib
    converted['ssd'] = resources['ssd'] * bytes_per_tib
    converted['hdd'] = resources['hdd'] * bytes_per_tib
    return converted


def dispenser_project_url(project_key):
    """Return the Dispenser UI URL of the project ``project_key`` (DISPENSER_URL + ``clouds/projects/<key>``)."""
    project_path = "clouds/projects/{}".format(project_key)
    return DISPENSER_URL + project_path


def update_collection_record(id, updating_fields=None, finish_flag=False):
    """Update a record of ``requests_queue``, optionally moving it to ``completed_requests``.

    ``lastUpdated`` is always set to now.

    :param id: ``_id`` of the record in ``requests_queue``.
    :param updating_fields: fields to set on the record. The caller's dict is
        copied, not modified.
    :param finish_flag: when false, ``$set`` the fields in place. When true,
        merge them into the record, insert it into ``completed_requests``
        (an already-archived copy is tolerated) and remove it from the queue.
        If the record is no longer in the queue, a warning is logged and
        nothing else happens.
    """
    # Copy so the caller's dict does not gain a ``lastUpdated`` key.
    updating_fields = dict(updating_fields or {})
    updating_fields["lastUpdated"] = datetime.datetime.now()
    collection = mongo_db["requests_queue"]

    if not finish_flag:
        collection.update({"_id": id}, {"$set": updating_fields})
        return

    record = collection.find_one({"_id": id})
    if record is None:
        # Possibly already finished and moved to completed_requests.
        logger.warning("Request %s not found in requests_queue; nothing to finish", id)
        return
    record.update(updating_fields)

    completed_collection = mongo_db["completed_requests"]
    try:
        completed_collection.insert(record)
    except pymongo.errors.DuplicateKeyError:
        # Already archived; still drop it from the queue below.
        pass

    collection.remove({"_id": id})


def get_quota_value(value):
    """Convert a quota attribute value to float.

    Numeric strings may use a comma as the decimal separator (``"1,5"`` -> 1.5).
    The string ``"None"`` and an actual ``None`` both map to 0.0. Previously an
    actual ``None`` reached ``float("None")`` and raised ValueError.
    """
    if value is None or value == "None":
        return 0.0
    return float(str(value).replace(",", "."))


def _parse_modified_at(item):
    """Parse ``item["modified_at"]`` (text before any '+' suffix) into a datetime, or return now() if unparseable."""
    try:
        timestamp = item["modified_at"].split('+')[0]
    except (KeyError, AttributeError):
        return datetime.datetime.now()
    for format_string in ("%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S.%fZ"):
        try:
            return datetime.datetime.strptime(timestamp, format_string)
        except ValueError:
            continue
    return datetime.datetime.now()


def reject_duplicates_for_service(service_id, requests):
    """Deprive duplicate quota requests via abc_api.deprive_quota_request().

    Requests are grouped by their ``(location, segment)`` attributes and
    ordered by modification time. A request is a duplicate of the one just
    before it when all of the following hold:
    - they were modified less than REJECT_TIME_DELTA apart;
    - their Account objects have equal ``__dict__``;
    - their ``gencfg-groups`` attributes match.
    Unparseable timestamps fall back to the current time.

    :param service_id: ABC service id; not used by this implementation.
    :param requests: quota request dicts with ``id``, ``modified_at`` and
        ``resource.attributes`` (a list of ``{"name": ..., "value": ...}``).
        Attributes named in QUOTA_NUMBERS_KEYS are parsed with get_quota_value().
    """
    if not requests:
        return

    resources_by_dc = defaultdict(list)
    for item in requests:
        info = {}
        for attr in item['resource']['attributes']:
            name = attr["name"]
            info[name] = get_quota_value(attr["value"]) if name in QUOTA_NUMBERS_KEYS else attr["value"]
        account = Account(info["cpu"],
                          info["memory"],
                          info["hdd"],
                          info["ssd"],
                          info.get("ip4", 0),
                          # The key must be "io_ssd" (as in QUOTA_NUMBERS_KEYS); the
                          # former "io_sdd" typo forced io_ssd to 0 for every request.
                          io_ssd=info.get("io_ssd", 0),
                          io_hdd=info.get("io_hdd", 0),
                          net_bandwidth=info.get("net_bandwidth", 0))
        resources_by_dc[(info['location'], info['segment'])].append(
            [account, _parse_modified_at(item), info.get('gencfg-groups', ""), item['id']])

    for entries in resources_by_dc.values():
        entries = sorted(entries, key=lambda entry: entry[1])
        # Consecutive pairs; entries are time-sorted, so the delta is never negative.
        for previous, current in zip(entries, entries[1:]):
            if (current[1] - previous[1] < REJECT_TIME_DELTA and
                    current[0].__dict__ == previous[0].__dict__ and
                    current[2] == previous[2]):
                abc_api.deprive_quota_request(current[-1])


def get_collection_by_type(type):
    """Return the Mongo collection ``dumped_info_<type>`` holding dumped info of the given type."""
    collection_name = "dumped_info_" + type
    return mongo_db[collection_name]
