# -*- coding: utf-8 -*-
import argparse
import typing
from collections import namedtuple
from dataclasses import dataclass
from time import sleep

import yt.wrapper
import logging

import requests

# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)

# Request type sent both as the "fraud-request-type" query parameter and as
# the "type" field of every uploaded document.
FRAUD_REQUEST_TYPE = "COLLAPSED_AGGRS"


class ServicePrefixKey(typing.NamedTuple):
    """Grouping key for pending documents: the (service, prefix) pair of a row."""

    service: typing.Optional[str]
    prefix: typing.Optional[int]


@yt.wrapper.yt_dataclass
class RequestData:
    """Typed row used when reading the source table with ``read_table_structured``.

    Every field is optional and defaults to ``None``.
    """
    id: typing.Optional[str] = None  # uploaded as the document "id"
    service: typing.Optional[str] = None  # first half of the batching key; "service" query parameter
    prefix: typing.Optional[int] = None  # second half of the batching key; "prefix" query parameter and body field
    channel_uri: typing.Optional[str] = None  # uploaded as "channel_uri"
    key: typing.Optional[str] = None  # uploaded as "key"
    key_value: typing.Optional[str] = None  # uploaded as "key_value"
    aggrs: typing.Optional[str] = None  # uploaded as the document "data"
    structured_aggrs: typing.Optional[str] = None  # uploaded as "user_context"
    txn_timestamp: typing.Optional[int] = None  # uploaded as "txn_timestamp"; units not visible here


@dataclass
class SendingSettings:
    """HTTP delivery parameters consumed by ``send_batch``.

    Field order matters: ``main`` constructs this class positionally.
    """
    host: str  # base URL; "/update?..." is appended to it
    timeout: int  # passed as ``timeout=`` to ``session.post``
    retries: int  # number of additional attempts after the first one
    retry_delay: float  # argument to ``sleep`` between attempts


def send_batch(session: requests.Session,
               request_settings: SendingSettings,
               key: ServicePrefixKey,
               batch: typing.List[dict]):
    """POST one batch of documents to the ``/update`` endpoint, with retries.

    The request is retried when the connection fails or the server answers
    with HTTP 529; any other response is checked with ``raise_for_status``
    and ends the loop. At most ``request_settings.retries + 1`` attempts are
    made (always at least one), sleeping ``request_settings.retry_delay``
    between attempts.

    Args:
        session: HTTP session used to issue the request.
        request_settings: Host, per-request timeout and retry policy.
        key: ``(service, prefix)`` pair shared by every document in ``batch``.
        batch: Documents to upload, sent as the ``"docs"`` field of the body.

    Raises:
        requests.exceptions.ConnectionError: The last allowed attempt still
            failed to connect.
        requests.exceptions.HTTPError: The response carried an error status,
            including HTTP 529 on the last allowed attempt.
    """
    url = f"{request_settings.host}/update" \
          f"?service={key.service}" \
          f"&prefix={key.prefix}" \
          f"&fraud-request-type={FRAUD_REQUEST_TYPE}" \
          f"&backend=nirvana"

    LOGGER.debug(url)
    LOGGER.debug(batch)

    payload = {
        "AddIfNotExists": True,
        "prefix": key.prefix,
        "docs": batch,
    }

    # A negative ``retries`` value still results in a single attempt, so a
    # batch is never dropped without even being sent.
    total_attempts = max(request_settings.retries, 0) + 1

    for attempt in range(1, total_attempts + 1):
        is_last_attempt = attempt == total_attempts
        try:
            response = session.post(url, json=payload, timeout=request_settings.timeout)
        except requests.exceptions.ConnectionError as error:
            # NOTE(review): other requests exceptions (e.g. read timeouts) are
            # not retried and propagate immediately -- confirm this is intended.
            if is_last_attempt:
                # Out of attempts: surface the failure instead of silently
                # losing the batch.
                raise
            LOGGER.warning("Connection error on attempt %d/%d: %s",
                           attempt, total_attempts, error)
        else:
            if response.status_code != 529 or is_last_attempt:
                # Raises for any error status, including a final 529.
                response.raise_for_status()
                return
            LOGGER.warning("HTTP 529 on attempt %d/%d, retrying",
                           attempt, total_attempts)

        sleep(request_settings.retry_delay)


# Command-line interface of the script. Options that declare a default are
# optional: combining ``required=True`` with ``default=`` would make the
# default unreachable.
parser = argparse.ArgumentParser()
parser.add_argument('--proxy', type=str, required=True,
                    help='YT proxy passed to the YT client')
parser.add_argument('--src', type=str, required=True,
                    help='path of the source YT table')
parser.add_argument('--batch-size', type=int, required=True,
                    help='batch size threshold for documents grouped by (service, prefix)')
parser.add_argument('--host', type=str, required=True,
                    help='base URL of the target service; "/update" is appended to it')
parser.add_argument('--timeout-sec', type=int, default=1,
                    help='timeout of a single HTTP request, in seconds')
parser.add_argument('--retry-delay-sec', type=float, default=0.5,
                    help='pause between attempts, in seconds')
parser.add_argument('--retries', type=int, default=10,
                    help='number of additional attempts after the first one')


def main():
    """Read the source table and upload its rows in per-key batches.

    Rows are grouped by their ``(service, prefix)`` pair. A group is sent as
    soon as it holds ``--batch-size`` documents; whatever remains after the
    table has been read is sent at the end.
    """
    args = parser.parse_args()

    request_settings = SendingSettings(
        host=args.host,
        timeout=args.timeout_sec,
        retries=args.retries,
        retry_delay=args.retry_delay_sec,
    )

    client = yt.wrapper.YtClient(proxy=args.proxy)

    # Documents waiting to be sent, keyed by ServicePrefixKey.
    batched_requests = {}

    with requests.Session() as session:
        rows = client.read_table_structured(
            yt.wrapper.TablePath(args.src),
            RequestData,
        )
        for row in rows:
            key = ServicePrefixKey(row.service, row.prefix)

            batch = batched_requests.setdefault(key, [])

            batch.append({
                "id": row.id,
                "data": row.aggrs,
                "user_context": row.structured_aggrs,
                "channel_uri": row.channel_uri,
                "key": row.key,
                "key_value": row.key_value,
                "type": FRAUD_REQUEST_TYPE,
                "txn_timestamp": row.txn_timestamp,
            })

            # Flush once the batch reaches the configured size; ``>`` would
            # let every request carry one document more than requested.
            if len(batch) >= args.batch_size:
                # pop() removes the key so the next row starts a fresh batch.
                send_batch(session, request_settings, key, batched_requests.pop(key))

        # Send the partially filled batches left after the table is exhausted.
        for key, batch in batched_requests.items():
            send_batch(session, request_settings, key, batch)


# Run the upload only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
