import functools
import json
import logging
import re
import time
import typing
from datetime import timedelta

import aiohttp
from asgiref.sync import async_to_sync
from django.conf import settings
from django.contrib.auth import get_user_model
from django.shortcuts import get_object_or_404
from django.utils.datetime_safe import datetime, strftime
from tvm2 import TVM2
from tvmauth import BlackboxTvmId

from mail.so.daemons.antifraud.antifraud_django.sofraud.library.role import read_roles_from_config

# Role-set name -> collection of channels, read once at import time from
# CHANNELS_CONFIG; empty when no config is set or the config yields nothing.
if settings.CHANNELS_CONFIG:
    ROLE_MAP = read_roles_from_config(settings.CHANNELS_CONFIG) or {}
else:
    ROLE_MAP = {}

# Every channel granted by any role set, sorted (superusers get all of them).
ALL_ROLES = sorted(functools.reduce(lambda acc, roles: acc | roles, ROLE_MAP.values(), set()))

# Django groups granting a role set are named "<GROUP_PREFIX><SET_PREFIX><set name>".
GROUP_PREFIX = "af."
SET_PREFIX = "set."
GROUP_SET_PREFIX = GROUP_PREFIX + SET_PREFIX

logger = logging.getLogger(__name__)


def backend_session(project) -> aiohttp.ClientSession:
    """Create an aiohttp session authenticated for *project*'s so-fraud backend.

    A TVM service ticket for the project's backend is obtained and attached
    to every request as the ``X-Ya-Service-Ticket`` header. The caller owns
    the returned session and must close it (callers here use ``async with``).
    """
    destination = settings.SO_FRAUD_BACKEND_TVM_ID[project]
    blackbox = BlackboxTvmId.Test if settings.DEBUG_AUTH else BlackboxTvmId.Prod

    tvm = TVM2(
        client_id=settings.TVM_CLIENT_ID,
        secret=settings.TVM_SECRET,
        blackbox_client=blackbox,
        allowed_clients=(),
        destinations=(destination,),
    )
    ticket = tvm.get_service_ticket(destination)

    return aiohttp.ClientSession(headers={'X-Ya-Service-Ticket': ticket})


# Textual datetime layout shared by the two converters below, e.g. "2021-03-04 05:06:07".
FORMAT = "%Y-%m-%d %H:%M:%S"


def datetime_to_str(dt: datetime):
    """Render *dt* as text using FORMAT (via the module-level ``strftime`` helper)."""
    rendered = strftime(dt, FORMAT)
    return rendered


def datetime_from_str(string: str):
    """Parse *string*, written in FORMAT, into a naive datetime."""
    parsed = datetime.strptime(string, FORMAT)
    return parsed


def field_to_datetime(data, field, dst_field=None):
    """Convert a millisecond epoch value in ``data[field]`` to a datetime, in place.

    The result (naive local time, sub-second part dropped by integer division)
    is stored under *dst_field* when it is truthy, otherwise back into *field*.
    Nothing happens when *field* is absent from *data*.
    """
    if field not in data:
        return
    target = dst_field if dst_field else field
    millis = int(data[field])
    data[target] = datetime.fromtimestamp(millis // 1000)


def recursive_fix_amounts(value):
    """Divide every amount in *value* by 100, recursing through dicts and lists.

    Scalars of type int, float or str are converted with ``float(value) / 100``.
    A string that does not parse is logged and returned unchanged. JSON
    ``null``/``true``/``false`` carry no amount and are returned as-is instead
    of aborting the whole response. Any other type raises.

    :raises Exception: for types that cannot appear in decoded JSON.
    """
    # None/bool must be checked first: bool is a subclass of int and would
    # otherwise be scaled (True -> 0.01) if the checks below used isinstance.
    if value is None or isinstance(value, bool):
        return value

    t = type(value)
    if t in (int, float, str):
        try:
            return float(value) / 100
        except (ValueError, OverflowError) as e:
            logger.warning(f"cannot parse float from '{value}': {e}")
            return value

    if t is dict:
        return {k: recursive_fix_amounts(v) for k, v in value.items()}

    if t is list:
        return [recursive_fix_amounts(v) for v in value]

    raise Exception("unsupported type " + str(t))


# Keys whose values are copied through untouched (not rescaled, not converted).
_VERBATIM_KEYS = frozenset({"data", "trust"})
# Any key containing one of these substrings holds an amount in hundredths.
_AMOUNT_KEY_RE = re.compile(r"amount|amnt")
# Any key containing one of these substrings holds a millisecond epoch value.
# (The previous pattern ".*(timestamp)|(first_acquire)|(last_acquire).*" only
# matched the *_acquire names at the very start of a key because of
# alternation precedence.)
_TIME_KEY_RE = re.compile(r"timestamp|first_acquire|last_acquire")


def recurse_fix_time_and_amounts(value: typing.Any) -> typing.Any:
    """Post-process a decoded backend JSON payload.

    Walks dicts and lists recursively. For each dict key:

    * ``"data"`` / ``"trust"``: value kept verbatim;
    * key containing ``amount``/``amnt``: value passed to ``recursive_fix_amounts``;
    * key containing ``timestamp``/``first_acquire``/``last_acquire``: the
      millisecond epoch is turned into a naive local ``datetime`` (sub-second
      part dropped); a ``null`` value stays ``None``;
    * anything else: processed recursively.

    Returns a new structure; the input is not modified.
    """
    if isinstance(value, dict):
        fixed = {}
        for k, v in value.items():
            if k in _VERBATIM_KEYS:
                fixed[k] = v
            elif _AMOUNT_KEY_RE.search(k):
                fixed[k] = recursive_fix_amounts(v)
            elif _TIME_KEY_RE.search(k):
                fixed[k] = None if v is None else datetime.fromtimestamp(int(v) // 1000)
            else:
                fixed[k] = recurse_fix_time_and_amounts(v)
        return fixed
    if isinstance(value, list):
        return [recurse_fix_time_and_amounts(v) for v in value]
    return value


async def get_transactions(project: str,
                           session: aiohttp.ClientSession,
                           channel_uri: str,
                           type: str,
                           limit: int,
                           query: str,
                           prefix: typing.Optional[int] = None) -> dict:
    """Query ``/get_transactions`` of *project*'s backend for one channel.

    *channel_uri* is split into channel / sub-channel. *prefix* is only sent
    when given. The decoded response goes through
    ``recurse_fix_time_and_amounts`` and its ``"transactions"`` entry is returned.

    :raises Exception: carrying status and body when the backend answers >= 400.
    """
    channel, sub_channel = split_channel_uri(channel_uri)

    payload = dict(channel=channel, sub_channel=sub_channel, type=type, limit=limit, query=query)
    if prefix is not None:
        payload["prefix"] = prefix

    logger.debug(f"get_transactions:request_data:{payload}")

    url = settings.SO_FRAUD_BACKEND_HOST[project] + '/get_transactions'
    async with session.post(url, ssl=settings.SSL_CONTEXT, json=payload) as resp:
        if resp.status >= 400:
            raise Exception(f"{resp.status} {await resp.text()}")
        body = recurse_fix_time_and_amounts(await resp.json())

    transactions = body['transactions']
    logger.debug(f"get_transactions:response_data:{transactions}")
    return transactions


async def get_transactions_by_channels(project: str,
                                       channels: typing.List[str],
                                       type: str,
                                       limit: int,
                                       query: str) -> typing.Tuple[dict, list]:
    """Fetch transactions for several channels over one backend session.

    Records are de-duplicated by ``id``: ``MAIN`` and ``SAVE`` records share
    one bucket, ``AGGRS`` records go to another. A failure for one channel
    (request error or unexpected record type) is logged and collected; the
    remaining channels are still queried.

    :returns: ``({"transactions": [...], "aggregates": [...]}, [(channel, exc), ...])``
    """
    by_id = {}
    aggregates_by_id = {}
    # MAIN and SAVE intentionally land in the same dict.
    bucket_for_type = {"MAIN": by_id, "SAVE": by_id, "AGGRS": aggregates_by_id}
    errors = []

    async with backend_session(project) as session:
        for ch in channels:
            try:
                for item in await get_transactions(project, session, ch, type, limit, query):
                    kind = item["type"]
                    item_id = item["id"]
                    bucket = bucket_for_type.get(kind)
                    if bucket is None:
                        raise Exception(f"unexpected type:{kind}")
                    bucket[item_id] = item
            except Exception as e:
                logger.exception("get_transactions_by_channels", exc_info=e)
                errors.append((ch, e))

    return {
        'transactions': list(by_id.values()),
        'aggregates': list(aggregates_by_id.values()),
    }, errors


@functools.lru_cache
@async_to_sync
async def _get_transactions_by_channels_cached(project: str,
                                               channels: typing.Tuple[str, ...],
                                               type: str,
                                               limit: int,
                                               query: str) -> typing.Tuple[dict, list]:
    # Memoized worker: lru_cache needs hashable arguments, hence the tuple.
    return await get_transactions_by_channels(project, list(channels), type, limit, query)


def get_transactions_by_channels_sync(project: str,
                                      channels: typing.Iterable[str],
                                      type: str,
                                      limit: int,
                                      query: str) -> typing.Tuple[dict, list]:
    """Blocking, memoized variant of ``get_transactions_by_channels``.

    *channels* may be any iterable (a list, as the annotation always promised,
    or a tuple). It is normalized to a tuple so it can serve as an lru_cache
    key; previously passing a list raised ``TypeError: unhashable type``.

    NOTE(review): results, including the per-channel error list, are cached
    by ``functools.lru_cache`` with no expiry, and the same dict object is
    returned on every hit — callers must not mutate it.
    """
    return _get_transactions_by_channels_cached(project, tuple(channels), type, limit, query)


# Keep the lru_cache management API reachable under the public name.
get_transactions_by_channels_sync.cache_info = _get_transactions_by_channels_cached.cache_info
get_transactions_by_channels_sync.cache_clear = _get_transactions_by_channels_cached.cache_clear


async def get_verification_levels(project: str,
                                  query: str,
                                  limit: int) -> typing.List[dict]:
    """Query ``/get_verification_levels`` of *project*'s backend.

    Opens its own backend session. The decoded JSON is passed through
    ``recurse_fix_time_and_amounts`` before being returned.

    :raises aiohttp.ClientResponseError: on a 4xx/5xx answer.
    """
    async with backend_session(project) as session:
        endpoint = settings.SO_FRAUD_BACKEND_HOST[project] + '/get_verification_levels'
        body = {"query": query, "limit": limit}
        async with session.post(endpoint, ssl=settings.SSL_CONTEXT, json=body) as resp:
            resp.raise_for_status()
            levels = await resp.json()

    return recurse_fix_time_and_amounts(levels)


@async_to_sync
async def get_verification_levels_sync(project: str,
                                       query: str,
                                       limit: int) -> typing.List[dict]:
    """Blocking variant of ``get_verification_levels`` for synchronous callers."""
    levels = await get_verification_levels(project, query, limit)
    return levels


async def get_lists(project: str,
                    session: aiohttp.ClientSession,
                    channel_uri,
                    list_name,
                    value,
                    limit):
    """Query ``/get_list`` for entries of *list_name* in one channel.

    *channel_uri* is split into channel / sub-channel. In every returned item,
    ``"from"`` and ``"to"`` (when present) are converted from millisecond
    epochs to datetimes via ``field_to_datetime``.

    :raises aiohttp.ClientResponseError: on a 4xx/5xx answer.
    """
    channel, sub_channel = split_channel_uri(channel_uri)
    request_data = {
        "channel": channel,
        "sub_channel": sub_channel,
        "list_name": list_name,
        "value": value,
        "limit": limit,
    }
    # Log the payload that is actually sent (previously this dumped locals(),
    # including the session object, under the "request_data" label).
    logger.debug(f"get_lists:request_data:{request_data}")

    async with session.post(settings.SO_FRAUD_BACKEND_HOST[project] + '/get_list',
                            ssl=settings.SSL_CONTEXT,
                            json=request_data) as response:
        response.raise_for_status()
        items = await response.json()

    for item in items:
        field_to_datetime(item, "from")
        field_to_datetime(item, "to")

    return items


@async_to_sync
async def get_lists_sync(project: str, channel_uri, list_name, value, limit):
    """Blocking wrapper around ``get_lists`` that opens its own backend session."""
    async with backend_session(project) as session:
        items = await get_lists(project, session, channel_uri, list_name, value, limit)
    return items


def timed_cache(ttl):
    """Decorator factory: memoize results per argument list for *ttl*.

    The cache key is ``json.dumps(args) + json.dumps(kwargs)``, so every
    argument must be JSON-serializable (``TypeError`` otherwise). Entries are
    never evicted; a stale entry is recomputed on its next lookup. Exceptions
    are not cached.

    Coroutine functions are supported: the awaited *result* is cached.
    Previously the coroutine object itself was cached, so a second call
    within *ttl* tried to re-await it and failed with ``RuntimeError``.

    Freshness is measured with ``time.monotonic()``, so wall-clock changes
    (NTP, DST) cannot stretch or cut short an entry's lifetime.

    :param ttl: ``datetime.timedelta`` — how long a stored result stays valid.
    """
    import inspect  # local: only needed to detect coroutine functions

    ttl_seconds = ttl.total_seconds()

    def decorator(func):
        cache = {}  # key -> (monotonic time when stored, value)

        def make_key(args, kwargs):
            return json.dumps(args) + json.dumps(kwargs)

        def fresh_entry(key, now):
            # Return the (stored_at, value) pair if it is still within ttl.
            entry = cache.get(key)
            if entry is not None and now - entry[0] <= ttl_seconds:
                return entry
            return None

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                key = make_key(args, kwargs)
                now = time.monotonic()
                entry = fresh_entry(key, now)
                if entry is not None:
                    return entry[1]
                value = await func(*args, **kwargs)
                cache[key] = (now, value)
                return value

            return async_wrapper

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = make_key(args, kwargs)
            now = time.monotonic()
            entry = fresh_entry(key, now)
            if entry is not None:
                return entry[1]
            value = func(*args, **kwargs)
            cache[key] = (now, value)
            return value

        return wrapper

    return decorator


# How long a successful list-names lookup is reused.
_LISTS_NAMES_TTL = timedelta(minutes=5)
# (project, channel, sub_channel) -> (monotonic time when stored, names)
_lists_names_cache: typing.Dict[typing.Tuple[typing.Any, ...], typing.Tuple[float, typing.Any]] = {}


async def get_lists_names(project: str,
                          session: aiohttp.ClientSession,
                          channel,
                          sub_channel):
    """Return the list names defined for ``channel``/``sub_channel``.

    Successful answers are cached for ``_LISTS_NAMES_TTL`` per
    ``(project, channel, sub_channel)``. The session is deliberately left out
    of the key: ``get_all_lists_names`` opens a new one per call. The previous
    ``@timed_cache`` decoration built its key with ``json.dumps`` over all
    arguments, including the session, so every call raised ``TypeError``.

    A 4xx/5xx answer is logged and yields ``[]``; that fallback is not cached,
    so the next call retries. Network errors and the 0.3 s timeout propagate.
    """
    cache_key = (project, channel, sub_channel)
    now = time.monotonic()
    cached = _lists_names_cache.get(cache_key)
    if cached is not None and now - cached[0] <= _LISTS_NAMES_TTL.total_seconds():
        return cached[1]

    async with session.post(settings.SO_FRAUD_BACKEND_HOST[project] + '/get_lists_names',
                            timeout=0.3,
                            ssl=settings.SSL_CONTEXT,
                            json={
                                "channel": channel,
                                "sub_channel": sub_channel}) as response:
        try:
            response.raise_for_status()
        except aiohttp.ClientResponseError as e:
            logger.error(e)
            return []
        names = await response.json()

    _lists_names_cache[cache_key] = (now, names)
    return names


async def get_all_lists_names(project: str, channel_uris):
    """Map each channel URI to its list names, sharing one backend session.

    URIs are split with ``split_channel_uri``, like every other backend call
    in this module. A failure for one URI is logged through the module logger
    and that URI is left out of the result; the others are still queried.
    Previously a URI without exactly one '/' raised outside the ``try`` and
    aborted the whole lookup.
    """
    lists_names_by_channels = {}

    async with backend_session(project) as session:
        for channel_uri in channel_uris:
            try:
                channel, sub_channel = split_channel_uri(channel_uri)
                lists_names_by_channels[channel_uri] = await get_lists_names(project,
                                                                             session,
                                                                             channel,
                                                                             sub_channel)
            except Exception as e:
                logger.warning("get_all_lists_names: %s: %s", channel_uri, e)
    return lists_names_by_channels


def datetime_to_milliseconds(dt):
    """Return *dt* as whole milliseconds since the Unix epoch.

    Naive datetimes are interpreted as local time, as before. Aware datetimes
    now honour their own UTC offset; ``time.mktime(dt.timetuple())`` treated
    their wall-clock fields as local time. The sub-second part is truncated,
    so the result is always a multiple of 1000.
    """
    return int(dt.timestamp()) * 1000


def now_milliseconds():
    """Return the current local time as whole milliseconds since the epoch."""
    current = datetime.now()
    return datetime_to_milliseconds(current)


# (span, name) pairs ordered from finest to coarsest granularity; the name
# is used as the "txn_<name>" field in backend range queries.
TIME_PERIODS = (
    (timedelta(seconds=60), "minute"),
    (timedelta(minutes=60), "hour"),
    (timedelta(hours=24), "day"),
    (timedelta(days=7), "week"),
    (timedelta(weeks=4), "month"),
)


def make_timestamp_query(since: datetime, until: datetime):
    """Build a backend range query ``txn_<period>:[<since ms> TO <until ms>]``.

    ``<period>`` is the name of the coarsest TIME_PERIODS entry strictly
    shorter than ``until - since``. When none is shorter, the finest
    ("minute") is used.
    """
    span = until - since
    name = next(
        (label for period, label in reversed(TIME_PERIODS) if span > period),
        TIME_PERIODS[0][1],
    )
    lower = datetime_to_milliseconds(since)
    upper = datetime_to_milliseconds(until)
    return f"txn_{name}:[{lower} TO {upper}]"


def get_user_channels(user):
    """Return the sorted list of channels *user* may access.

    Superusers get every channel (a copy of ALL_ROLES). Other users get the
    union of the role sets named by their ``GROUP_SET_PREFIX``-prefixed
    groups. Unknown or empty sets are logged and skipped.
    """
    if user.is_superuser:
        # Return a copy: handing out the module-level list would let a caller's
        # mutation leak into every later request.
        return list(ALL_ROLES)

    channels = set()

    for g in user.groups.filter(name__startswith=GROUP_SET_PREFIX):
        name: str = g.name[len(GROUP_SET_PREFIX):]

        channel_set = ROLE_MAP.get(name)
        if not channel_set:
            logger.warning(f"not existing set: {locals()}")
            continue

        channels.update(channel_set)

    return sorted(channels)


def split_channel_uri(channel_uri: str) -> typing.Tuple[str, typing.Optional[str]]:
    """Split ``"channel/sub_channel"`` into its two parts.

    Without a '/', returns ``(channel_uri, None)``. Segments after the second
    '/' are ignored (``"a/b/c"`` -> ``("a", "b")``).
    """
    # Length check instead of catching IndexError; the old ValueError handler
    # was unreachable because str.split never raises it.
    parts = channel_uri.split('/')
    if len(parts) < 2:
        return channel_uri, None
    return parts[0], parts[1]


async def update_list(project: str,
                      session: aiohttp.ClientSession,
                      channel,
                      sub_channel,
                      list_name,
                      author,
                      from_timestamp,
                      to_timestamp,
                      items,
                      reason):
    """Add/update *items* in a FAST_LIST via the backend's ``/update_list``.

    The request is sent with no client timeout. It now uses
    ``settings.SSL_CONTEXT`` like every read call to the same backend; it was
    previously omitted here.

    :returns: the raw response body text.
    :raises aiohttp.ClientResponseError: on a 4xx/5xx answer.
    """
    data = {"channel": channel,
            "sub_channel": sub_channel,
            "list_name": list_name,
            "author": author,
            "from": from_timestamp,
            "to": to_timestamp,
            "type": "FAST_LIST",
            "action": "update",
            "items": items,
            "reason": reason
            }

    async with session.post(settings.SO_FRAUD_BACKEND_HOST[project] + '/update_list',
                            ssl=settings.SSL_CONTEXT,
                            json=data,
                            timeout=None) as response:
        response.raise_for_status()
        return await response.text()


async def delete_list(project: str,
                      session: aiohttp.ClientSession,
                      channel,
                      sub_channel,
                      list_name,
                      items):
    """Call the backend's ``/delete_list`` for *list_name*.

    *items* is sent as ``"value"`` only when it is truthy.
    NOTE(review): an empty selection therefore sends no ``"value"`` at all —
    presumably the backend then deletes the whole list; confirm callers never
    pass ``[]`` by accident.

    Uses ``settings.SSL_CONTEXT`` like the read calls (previously omitted)
    and no client timeout.

    :returns: the raw response body text.
    :raises aiohttp.ClientResponseError: on a 4xx/5xx answer.
    """
    data = {"channel": channel,
            "sub_channel": sub_channel,
            "list_name": list_name}

    if items:
        data["value"] = items

    async with session.post(settings.SO_FRAUD_BACKEND_HOST[project] + '/delete_list',
                            ssl=settings.SSL_CONTEXT,
                            json=data,
                            timeout=None) as response:
        response.raise_for_status()
        return await response.text()


async def update_counter(project: str,
                         session: aiohttp.ClientSession,
                         channel,
                         sub_channel,
                         key,
                         value):
    """Set counter *key* to *value* via the backend's ``/update_counter``.

    Uses ``settings.SSL_CONTEXT`` like the read calls (previously omitted)
    and no client timeout.

    :returns: the raw response body text.
    :raises aiohttp.ClientResponseError: on a 4xx/5xx answer.
    """
    data = {"channel": channel,
            "sub_channel": sub_channel,
            "key": key,
            "value": value
            }

    async with session.post(settings.SO_FRAUD_BACKEND_HOST[project] + '/update_counter',
                            ssl=settings.SSL_CONTEXT,
                            json=data,
                            timeout=None) as response:
        response.raise_for_status()
        return await response.text()


@async_to_sync
async def update_counter_sync(project: str,
                              channel,
                              sub_channel,
                              key,
                              value):
    """Blocking wrapper around ``update_counter`` that owns its backend session."""
    async with backend_session(project) as session:
        result = await update_counter(project, session, channel, sub_channel, key, value)
    return result


@timed_cache(timedelta(minutes=5))
def login_has_rights(login, pattern):
    """Tell whether user *login* is a superuser or in a group matching *pattern*.

    *pattern* is applied as a database regex to group names. Results are
    cached per ``(login, pattern)`` for five minutes by ``timed_cache``.

    :raises django.http.Http404: if no user has that username.
    """
    user = get_object_or_404(get_user_model(), username=login)
    if not user.is_superuser:
        return user.groups.filter(name__regex=pattern).exists()
    return True


def remove_nones(src):
    """Return a copy of *src* with ``None`` values dropped at every nesting level.

    Dicts lose keys whose value is ``None``; lists lose ``None`` elements.
    Any other value (including tuples and ``None`` itself) is returned as-is.
    """
    if isinstance(src, dict):
        return {key: remove_nones(item) for key, item in src.items() if item is not None}
    if isinstance(src, list):
        return [remove_nones(item) for item in src if item is not None]
    return src


@async_to_sync
async def make_refund_and_start(purchase_token, uid, reason):
    """Create a Trust refund for *purchase_token* and immediately start it.

    Both Trust calls now carry the same service headers; previously the
    ``/start`` request was sent with none.

    :returns: ``{"make_refund_response": ..., "start_refund_response": ...}``,
        the decoded JSON bodies of the two calls.
    :raises Exception: carrying status and body when either call answers >= 400.
    """
    headers = {
        # NOTE(review): this token is a credential committed to source. It is
        # kept only as a fallback — set settings.TRUST_SERVICE_TOKEN and rotate it.
        "x-service-token": getattr(settings, "TRUST_SERVICE_TOKEN",
                                   "passport_d056025fbea3c4700729c5b96b0ff97b"),
        "X-Uid": str(uid),
        "X-User-Ip": "127.0.0.1",
    }
    refunds_url = settings.TRUST_BACKEND_HOST + '/trust-payments/v2/refunds'

    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.post(refunds_url,
                                json={
                                    "purchase_token": purchase_token,
                                    # NOTE(review): hard-coded refund amount — confirm intended.
                                    "orders": [{"delta_amount": "11"}],
                                    "reason_desc": reason,
                                }) as response:
            if 400 <= response.status:
                raise Exception(f"{response.status} {await response.text()}")
            response_data = await response.json()

        logger.debug(f"trust-payments/v2/refunds:{response_data}")

        trust_refund_id = response_data["trust_refund_id"]
        start_path = f'/trust-payments/v2/refunds/{trust_refund_id}/start'
        logger.debug(start_path)

        async with session.post(settings.TRUST_BACKEND_HOST + start_path) as response:
            if 400 <= response.status:
                raise Exception(f"{response.status} {await response.text()}")
            start_response_data = await response.json()

    return {
        "make_refund_response": response_data,
        "start_refund_response": start_response_data,
    }
