import functools
import json
import logging
import time
from datetime import datetime, timedelta

import requests

from mail_search_webtools.settings import SO_FRAUD_BACKEND_HOST

logger = logging.getLogger(__name__)


def datetime_to_str(dt):
    """Format ``dt`` as ``'<date> <time>'`` using the ``%x`` and ``%X`` strftime codes."""
    date_and_time_format = '%x %X'
    return dt.strftime(date_and_time_format)


def field_to_datetime(data, field, dst_field=None):
    """Replace a millisecond timestamp stored in ``data`` with a formatted string.

    ``data[field]`` is converted with ``int()``, divided by 1000 and passed to
    ``datetime.fromtimestamp``; the formatted result is written to
    ``data[dst_field]`` when ``dst_field`` is given (and truthy), otherwise it
    overwrites ``data[field]``. Nothing happens when ``field`` is absent.
    """
    if field not in data:
        return

    seconds = int(data[field]) // 1000
    target = dst_field or field
    data[target] = datetime_to_str(datetime.fromtimestamp(seconds))


def get_transactions(channel_uri, type, limit, query):
    """Fetch transactions for one channel from the fraud backend.

    Args:
        channel_uri: ``"<channel>/<sub_channel>"``; must contain exactly one
            ``/`` or the unpacking below raises ``ValueError``.
        type: forwarded as the ``type`` field of the request body (the name
            shadows the builtin but is part of the public signature).
        limit: forwarded as ``limit``.
        query: forwarded as ``query``.

    Returns:
        The decoded JSON response. Every entry of ``response['transactions']``
        is normalized in place:

        * ``txn_timestamp`` / ``txn_status_timestamp`` (treated as epoch
          milliseconds) are replaced by formatted strings when present;
        * ``id`` loses its first four characters (presumably a backend
          prefix -- confirm against the backend);
        * ``amount``, when present, becomes ``int(amount) / 100``.

    Raises:
        requests.HTTPError: if the backend answers with an error status.
    """
    channel, sub_channel = channel_uri.split('/')
    response = requests.post(SO_FRAUD_BACKEND_HOST + '/get_transactions', json={
        "channel": channel,
        "sub_channel": sub_channel,
        "type": type,
        "limit": limit,
        "query": query
    })

    response.raise_for_status()

    payload = response.json()

    for transaction in payload['transactions']:
        field_to_datetime(transaction, 'txn_timestamp')
        field_to_datetime(transaction, 'txn_status_timestamp')
        transaction['id'] = transaction['id'][4:]
        if 'amount' in transaction:
            transaction['amount'] = int(transaction['amount']) / 100

    return payload


def get_transactions_by_channels(channels, type, limit, query):
    """Query every channel and merge the results, de-duplicated by ``id``.

    Each channel is fetched with ``get_transactions``. Transactions and
    aggregates are collected separately; when the same ``id`` is seen more
    than once, the entry from the later channel wins.

    Returns:
        ``{"transactions": [...], "aggregates": [...]}``.
    """
    transactions_by_id = {}
    aggregates_by_id = {}

    for channel_uri in channels:
        payload = get_transactions(channel_uri, type, limit, query)

        for txn in payload['transactions']:
            transactions_by_id[txn["id"]] = txn
        for aggregate in payload['aggregates']:
            aggregates_by_id[aggregate["id"]] = aggregate

    return {
        "transactions": list(transactions_by_id.values()),
        "aggregates": list(aggregates_by_id.values()),
    }


def get_lists(channel_uri, list_name, value):
    """Look up ``FAST_LIST`` entries for one channel in the fraud backend.

    ``channel_uri`` is ``"<channel>/<sub_channel>"``. Each returned item is
    adjusted in place: both timestamp fields are coerced to ``int``, their
    formatted forms are stored under ``from`` / ``to``, and ``author``,
    ``list_name`` and ``reason`` are filled from ``uid``, ``card`` and
    ``txn_afs_reason`` (``None`` when the latter is missing).

    Raises:
        Exception: carrying the response body when the status is not 200.
    """
    channel, sub_channel = channel_uri.split('/')
    request_body = {
        "channel": channel,
        "sub_channel": sub_channel,
        "list_name": list_name,
        "value": value,
        "type": "FAST_LIST",
    }
    response = requests.post(SO_FRAUD_BACKEND_HOST + '/get_list', json=request_body)

    if response.status_code != 200:
        raise Exception(response.content)

    entries = response.json()
    for entry in entries:
        for timestamp_key in ("txn_timestamp", "txn_status_timestamp"):
            entry[timestamp_key] = int(entry[timestamp_key])

        field_to_datetime(entry, "txn_timestamp", "from")
        field_to_datetime(entry, "txn_status_timestamp", "to")

        entry["author"] = entry["uid"]
        entry["list_name"] = entry["card"]
        entry["reason"] = entry.get("txn_afs_reason")

    return entries


def get_lists_by_channels(channels, list_name, value):
    """Concatenate the ``get_lists`` results of every channel, in channel order."""
    return [
        entry
        for channel_uri in channels
        for entry in get_lists(channel_uri, list_name, value)
    ]


def timed_cache(ttl):
    """Build a decorator that memoizes a function's results for ``ttl``.

    Args:
        ttl: a ``timedelta``; a cached value is reused until
            ``stored_at + ttl`` lies in the past (measured with
            ``datetime.now()``).

    The cache key is the JSON serialization of the call arguments, so every
    argument must be JSON-serializable (``json.dumps`` raises ``TypeError``
    otherwise). Keyword arguments are serialized in sorted order, so
    ``f(a=1, b=2)`` and ``f(b=2, a=1)`` share one entry. Expired entries are
    overwritten on the next call with the same key but never evicted.
    If the wrapped function raises, nothing is stored.
    """
    def decorator(func):
        # key -> (time the value was stored, value)
        entries = {}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            now = datetime.now()
            # Sort keyword arguments so the key does not depend on call-site order.
            key = json.dumps(args) + json.dumps(sorted(kwargs.items()))
            entry = entries.get(key)
            if entry is None or entry[0] + ttl < now:
                entry = (now, func(*args, **kwargs))
                entries[key] = entry
            return entry[1]

        return wrapper

    return decorator


@timed_cache(timedelta(minutes=5))
def get_lists_names(channel, sub_channel):
    """Ask the fraud backend for the list names of one channel/sub-channel.

    The request uses a 0.3 second timeout. An error status is logged and
    turned into an empty list; exceptions raised by the request itself
    (for example a timeout) propagate to the caller. Whatever is returned,
    including the empty list, is cached for five minutes per argument pair.
    """
    url = SO_FRAUD_BACKEND_HOST + '/get_lists_names'
    request_body = {
        "channel": channel,
        "sub_channel": sub_channel,
    }
    response = requests.post(url, timeout=0.3, json=request_body)

    try:
        response.raise_for_status()
    except Exception as err:
        logger.error(err)
        return []
    else:
        return response.json()


def get_all_lists_names(channel_uris):
    """Collect list names for several channels, skipping channels that fail.

    Args:
        channel_uris: iterable of ``"<channel>/<sub_channel>"`` strings. A
            string that does not split into exactly two parts raises
            ``ValueError`` (the split happens outside the ``try``).

    Returns:
        Dict mapping each channel URI to the result of ``get_lists_names``.
        Channels whose lookup raised are logged and left out of the dict.
    """
    lists_names_by_channels = {}
    for channel_uri in channel_uris:
        channel, sub_channel = channel_uri.split('/')
        try:
            lists_names_by_channels[channel_uri] = get_lists_names(channel, sub_channel)
        except Exception as e:
            # Best effort: one failing channel must not hide the others.
            # Use the module logger rather than the root logger.
            logger.warning(e)
    return lists_names_by_channels


def datetime_to_milliseconds(dt):
    """Convert ``dt`` to an integer millisecond timestamp via ``time.mktime``.

    The sub-second part of ``dt`` is dropped, so the result is always a
    multiple of 1000.
    """
    whole_seconds = int(time.mktime(dt.timetuple()))
    return whole_seconds * 1000


def now_milliseconds():
    """Return ``datetime.now()`` as a millisecond timestamp (second precision)."""
    current = datetime.now()
    return datetime_to_milliseconds(current)


# Bucket sizes in milliseconds, largest first, each paired with the suffix of
# the ``txn_<name>`` query field that make_timestamp_query emits for it.
# Note that a "month" is a fixed 28 days. ``total_seconds()`` returns a float,
# so each size is converted to ``int`` to keep timestamp arithmetic (and the
# timestamps rendered into queries) integral.
DELTAS = (
    (int(timedelta(days=28).total_seconds()) * 1000, "month"),
    (int(timedelta(weeks=1).total_seconds()) * 1000, "week"),
    (int(timedelta(days=1).total_seconds()) * 1000, "day"),
    (int(timedelta(hours=1).total_seconds()) * 1000, "hour"),
    (int(timedelta(minutes=1).total_seconds()) * 1000, "minute"),
)


def make_timestamp_query(since, until):
    """Build a backend query covering ``since``..``until`` with time buckets.

    Both bounds are converted with ``datetime_to_milliseconds``. ``since`` is
    floored to the minute; ``until`` is floored to the minute and then moved
    one minute forward. Starting at ``since``, the largest bucket from
    ``DELTAS`` that is aligned at the current position and still ends at or
    before ``until`` is emitted as ``txn_<name>:<start>``; when none fits, a
    minute bucket is emitted. The loop runs while the position is ``<= until``.

    Args:
        since: start of the range (object accepted by ``datetime_to_milliseconds``).
        until: end of the range (same kind of object).

    Returns:
        ``"(term OR term ...)"``, or ``"txn_timestamp:[since TO until]"`` when
        no bucket was produced (``since`` lies after the adjusted ``until``).
    """
    # Bucket sizes may be stored as floats; work with integers so timestamps
    # are rendered without a fractional part.
    steps = [(int(size), name) for size, name in DELTAS]
    minute = steps[-1][0]

    since = datetime_to_milliseconds(since)
    until = datetime_to_milliseconds(until)

    since -= since % minute
    # Floor ``until`` using its own remainder; the remainder of the already
    # aligned ``since`` is always zero and would leave ``until`` untouched.
    until -= until % minute
    until += minute

    terms = []
    while since <= until:
        # Largest aligned bucket that does not overshoot; minute bucket otherwise.
        size, name = next(
            (
                (candidate, label)
                for candidate, label in steps
                if since % candidate == 0 and since + candidate <= until
            ),
            steps[-1],
        )
        terms.append("txn_" + name + ":" + str(since))
        since += size

    if terms:
        return "(" + " OR ".join(terms) + ")"
    return "txn_timestamp:[" + str(since) + " TO " + str(until) + "]"
