import os
import time
import asyncio
import logging
import argparse

import aiohttp
import aiohttp.web
from ujson import loads, dumps

from infra.dproxy.proto import dproxy_pb2


# Root logger (no name is passed); main() configures it via logging.basicConfig().
log = logging.getLogger()


class UnistatServer:
    """Small aiohttp server exposing a ping endpoint and the current metrics.

    Routes:
        GET /         -- plain-text ``Ok`` response.
        GET /unistat  -- JSON dump of ``self.metrics``.
    """

    def __init__(self, port: int):
        """Register the routes; nothing is bound until ``start()`` is awaited.

        Args:
            port: TCP port the server listens on once started.
        """
        self.port = port
        self.app = aiohttp.web.Application()
        self.app.add_routes([
            aiohttp.web.get('/', self.ping),
            aiohttp.web.get('/unistat', self.unistat),
        ])
        # AppRunner of the running server, or None while the server is not started.
        self.runner = None
        # Payload returned by /unistat as-is; owners replace the whole list.
        self.metrics = []

    async def start(self):
        """Start listening on host ``'::'`` and ``self.port``; no-op if already started.

        If runner setup or binding fails, the runner is cleaned up and
        ``self.runner`` is reset before the error is re-raised, so a failed
        start does not leave the instance looking "started".
        """
        if self.runner is not None:
            return

        # Published before the first await so an overlapping start() call
        # sees the server as already starting.
        runner = self.runner = aiohttp.web.AppRunner(self.app)
        try:
            await runner.setup()
            await aiohttp.web.TCPSite(runner, '::', self.port).start()
        except BaseException:
            # BaseException so that cancellation during startup also releases
            # whatever the runner managed to acquire.
            self.runner = None
            await runner.cleanup()
            raise

    async def shutdown(self):
        """Stop the server if it is running; safe to call more than once."""
        # Detach the runner first so a repeated call never cleans it up twice.
        runner, self.runner = self.runner, None
        if runner is not None:
            await runner.cleanup()

    async def ping(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
        """Answer with ``200 Ok`` (text/plain) regardless of the request."""
        return aiohttp.web.Response(status=200, body=b'Ok', content_type='text/plain')

    async def unistat(self, _: aiohttp.web.Request) -> aiohttp.web.Response:
        """Return ``self.metrics`` serialized to JSON with the module's ``dumps``."""
        return aiohttp.web.json_response(self.metrics, dumps=dumps)


def _build_search_query(args, validation):
    """Build the log search request selecting the records of one validation run."""
    query = dproxy_pb2.SearchLogEntriesRequest()
    query.project_id = args.project
    query.deploy_unit_id = args.deploy_unit
    query.stage_id = args.stage
    query.box_list.values.append(args.box)
    query.pod_list.values.append(args.pod)
    uf = query.user_field_list.add()
    uf.path = 'validating'
    uf.values.append(validation['marker'])
    # The marker stores seconds; the request range is set in microseconds.
    query.timestamp_range.begin.FromMicroseconds(int(validation['start_timestamp'] * 1e6))
    query.order = dproxy_pb2.ASC
    query.limit = 50
    return query


def _has_reordering(numbers):
    """Return True if any counter is smaller than the one received right before it."""
    return any(prev > cur for prev, cur in zip(numbers, numbers[1:]))


async def iteration(args, token, http, headers, last_mtime):
    """Run one validation pass if the marker file changed since ``last_mtime``.

    The marker file is read as JSON; the keys used are ``marker``,
    ``start_timestamp``, ``end_timestamp`` (both in seconds) and ``records``.
    Log entries tagged with the marker are fetched page by page until the
    expected number of records has been read, ``args.retries`` consecutive
    attempts brought no full page, or ``args.max_delay_time`` seconds passed.

    Args:
        args: parsed command-line arguments (marker file, dproxy URL, deploy
            identifiers, retry and delay limits).
        token: not used here; credentials are expected to be in ``headers``.
        http: aiohttp client session used for the POST requests.
        headers: HTTP headers sent with every request.
        last_mtime: marker file mtime seen by the previous pass, or None.

    Returns:
        ``(mtime, stats)``, where ``stats`` is a dict of metrics, or None when
        the marker file is unchanged and no check was performed.

    Raises:
        OSError, ValueError, KeyError: when the marker file cannot be read,
            parsed, or lacks one of the expected keys.
    """
    with open(args.marker_file) as f:
        st = os.fstat(f.fileno())
        if last_mtime is not None and last_mtime == st.st_mtime:
            log.info("last_mtime not changed, logs check not needed")
            return last_mtime, None

        last_mtime = st.st_mtime
        validation = loads(f.read().strip())

    log.debug("Got new validation info: %s", validation)

    query = _build_search_query(args, validation)

    start_timestamp = time.time()
    first_record_timestamp = None
    last_record_timestamp = None
    numbers = []

    retries = 0
    deadline = start_timestamp + args.max_delay_time

    while True:
        if retries == args.retries:
            break

        if time.time() > deadline:
            break

        try:
            async with http.post(args.dproxy, data=query.SerializeToString(), headers=headers) as resp:
                resp.raise_for_status()
                result = dproxy_pb2.SearchLogEntriesResponse()
                result.MergeFromString(await resp.read())

            # Decode the whole page before touching the accumulated state.
            # If one entry fails to decode, the same page is requested again
            # (the continuation token has not moved), so committing entries
            # one by one would count the already-decoded ones twice.
            page = [
                (entry.timestamp.ToNanoseconds() / 1e9, loads(entry.context)['counter'])
                for entry in result.log_entries
            ]
        except asyncio.CancelledError:
            # On Python < 3.8 CancelledError derives from Exception; never
            # treat cancellation as a retryable failure.
            raise
        except Exception as e:
            log.exception("chunk retrieval failed (attempt %d), will retry: %s", retries + 1, e)
            await asyncio.sleep(5.)
            retries += 1
            continue

        if page:
            if first_record_timestamp is None:
                first_record_timestamp = page[0][0]
            last_record_timestamp = page[-1][0]
            numbers.extend(counter for _, counter in page)

        if len(numbers) == validation['records']:
            break

        if page:
            query.continuation_token = result.continuation_tokens.forward

        if len(page) == query.limit:
            # A full page means more data is probably available right away:
            # reset the attempt counter and fetch the next page without sleeping.
            retries = 0
            continue

        # data not ready yet, so sleep a little
        log.debug("chunk not ready (attempt #%d), will retry", retries + 1)
        await asyncio.sleep(1.)
        retries += 1

    end_timestamp = time.time()
    log.info("retrieval finished")

    # Compare neighbours rather than counter-vs-position: the positional
    # check reports missing records as reordering and can miss real
    # reordering when duplicates shift the positions.
    reordering = _has_reordering(numbers)
    unique_count = len(set(numbers))

    stats = {
        'start_timestamp_ammm': start_timestamp,
        'start_processing_delay_ammm': start_timestamp - validation['end_timestamp'],
        'end_timestamp_ammm': end_timestamp,
        'end_processing_delay_ammm': end_timestamp - validation['end_timestamp'],
        'processing_time_ammm': end_timestamp - start_timestamp,
        'records_written_ammm': validation['records'],
        'records_read_ammm': len(numbers),
        'missing_ammm': validation['records'] - unique_count,
        'reordering_ammm': int(reordering),
        'duplicates_ammm': len(numbers) - unique_count,
    }
    if numbers:
        # Record timestamps exist only when at least one entry was read.
        stats.update({
            'first_record_timestamp_ammm': first_record_timestamp,
            'first_record_timestamp_delay_ammm': first_record_timestamp - validation['start_timestamp'],
            'last_record_timestamp_ammm': last_record_timestamp,
            'last_record_timestamp_delay_ammm': last_record_timestamp - validation['end_timestamp'],
        })

    return last_mtime, stats


async def async_main(args, token):
    """Serve metrics over HTTP and re-run the validation pass once a second, forever.

    Args:
        args: parsed command-line arguments; ``unistat_port`` is used here,
            the rest is forwarded to ``iteration``.
        token: OAuth token sent as a ``Bearer`` Authorization header.

    A failed pass is logged and retried on the next tick; the metrics
    published by the last successful pass stay in place. The function only
    ends by an exception (including cancellation), and the HTTP server is shut
    down on the way out.
    """
    timeout = aiohttp.ClientTimeout(total=90)
    headers = {
        'Authorization': f'Bearer {token}',
        'Accept': 'application/x-protobuf',
        'Content-Type': 'application/x-protobuf',
    }
    last_mtime = None

    unistat_server = UnistatServer(port=args.unistat_port)

    try:
        # Started inside the try block so that a failed start (for example a
        # port that cannot be bound) still goes through shutdown().
        await unistat_server.start()
        log.info("HTTP server started on port %d", args.unistat_port)

        async with aiohttp.ClientSession(timeout=timeout) as http:
            while True:
                try:
                    last_mtime, data = await iteration(args, token, http, headers, last_mtime)
                except asyncio.CancelledError:
                    # On Python < 3.8 CancelledError derives from Exception;
                    # let cancellation end the loop instead of logging it.
                    raise
                except Exception as e:
                    log.exception("iteration failed: %s", e)
                else:
                    # data is None when the marker file has not changed.
                    if data is not None:
                        unistat_server.metrics = list(data.items())
                        log.info("unistat metrics updated")

                await asyncio.sleep(1)
    finally:
        await unistat_server.shutdown()


def main():
    """Parse the command line, configure logging and run ``async_main`` until it ends.

    The deploy identifiers default to the ``DEPLOY_*`` environment variables.
    The process exits with a usage error when any of them is missing, since
    every search request needs all five.

    Raises:
        ValueError: if the ``DPROXY_OAUTH_TOKEN`` environment variable is
            unset or empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--unistat-port', type=int, default=80)
    parser.add_argument('--marker-file', default='validating-marker.txt')
    parser.add_argument('--dproxy', default='https://ui-proxy.deploy.yandex-team.ru/api/logs/SearchLogEntries/')
    parser.add_argument('--project', default=os.getenv('DEPLOY_PROJECT_ID'))
    parser.add_argument('--stage', default=os.getenv('DEPLOY_STAGE_ID'))
    parser.add_argument('--deploy-unit', default=os.getenv('DEPLOY_UNIT_ID'))
    parser.add_argument('--box', default=os.getenv('DEPLOY_BOX_ID'))
    parser.add_argument('--pod', default=os.getenv('DEPLOY_POD_ID'))
    parser.add_argument('--retries', type=int, default=10)
    parser.add_argument('--max-delay-time', type=int, default=300)

    args = parser.parse_args()

    # Fail at startup, like the token check below, instead of failing on
    # every polling iteration when an identifier is absent.
    missing = [
        flag
        for flag, value in (
            ('--project', args.project),
            ('--stage', args.stage),
            ('--deploy-unit', args.deploy_unit),
            ('--box', args.box),
            ('--pod', args.pod),
        )
        if not value
    ]
    if missing:
        parser.error(
            "missing required value(s): %s (pass the option or set the matching DEPLOY_* variable)"
            % ', '.join(missing)
        )

    logging.basicConfig(format='[%(asctime)s.%(msecs)003d]  %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)

    token = os.getenv('DPROXY_OAUTH_TOKEN')
    if not token:
        raise ValueError("DPROXY_OAUTH_TOKEN is empty")

    # Create the loop explicitly: asyncio.get_event_loop() without a current
    # loop is deprecated and no longer creates one on recent Python versions.
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(async_main(args, token))
    finally:
        loop.close()
