#!/usr/bin/env python3
import asyncio
import logging.config
import os
from collections import defaultdict
from enum import unique
from itertools import count
from random import randint
from typing import Optional

import sqlalchemy as sa

from sendr_aiopg import StorageBase, StorageContextBase, create_engine
from sendr_aiopg.action import BaseDBAction
from sendr_aiopg.engine.single import CustomEngine
from sendr_core import BaseCoreContext
from sendr_core.exceptions import CoreFailError
from sendr_taskqueue import (
    BaseActionStorageWorker, BaseAsyncDBAction, BaseStorageArbiterWorker, BaseStorageWorker,
    BaseStorageWorkerApplication, BaseTaskType, BaseWorkerType
)
from sendr_taskqueue.logger import default_logger
from sendr_taskqueue.worker.storage import get_task_mapper, get_worker_mapper
from sendr_utils import copy_context

# Shared SQLAlchemy metadata for the task-queue tables. Tables defined against it default to
# the 'sendr_qtools' schema; it is passed to the task/worker mapper factories on ExampleStorage.
metadata = sa.MetaData(schema='sendr_qtools')

# App core begin

# Configure logging once at import time: everything goes through a single StreamHandler
# using the sendr_qlog UniFormatter. dictConfig sets the root logger's level itself
# ('root': {'level': 'DEBUG'}), so no separate logging.root.setLevel() call is needed --
# one made beforehand would simply be overwritten here.
logging.config.dictConfig({
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'qloud': {'()': 'sendr_qlog.UniFormatter'},
    },
    'handlers': {
        'default': {'level': 'DEBUG', 'formatter': 'qloud', 'class': 'logging.StreamHandler'}
    },
    'root': {'handlers': ['default'], 'level': 'DEBUG'}
})


@unique
class WorkerType(BaseWorkerType):
    """Worker kinds used by this example; ``@unique`` rejects duplicate values."""

    RUN_ACTION = 'run_action'  # ExampleActionWorker
    WORKER_TYPE_MAPPER = 'mapper'  # ExampleMapperWorker
    WORKER_TYPE_REDUCER = 'reducer'  # ExampleReducerWorker


@unique
class TaskType(BaseTaskType):
    """Task kinds used by this example; ``@unique`` rejects duplicate values."""

    RUN_ACTION = 'run_action'  # task_type of ExampleAsyncAction; handled by ExampleActionWorker
    TASK_TYPE_MAP = 'map'  # created by push_tasks(); mapped to MapAction by ExampleMapperWorker
    TASK_TYPE_REDUCE = 'reduce'  # created by MapAction; mapped to ReduceAction by ExampleReducerWorker


class ExampleStorage(StorageBase):
    """Storage exposing ``task`` and ``worker`` mappers over the tables in ``metadata``.

    The annotations below are real expressions evaluated at class creation; presumably
    StorageBase reads them to attach the mappers, so postponed annotation evaluation
    (``from __future__ import annotations``) would likely break this -- confirm.
    """

    task: get_task_mapper(metadata, TaskType)  # used as storage.task.create(...) in this file
    worker: get_worker_mapper(metadata, WorkerType)


class ExampleStorageContext(StorageContextBase):
    """Storage context whose storage class is ExampleStorage."""

    STORAGE_CLS = ExampleStorage


class ExampleContext(BaseCoreContext):
    """Context object for this example's actions.

    ExampleStorageWorker._run() and push_tasks() assign request_id, logger, db_engine
    and storage on the shared instance (request_id/logger are presumably declared on
    BaseCoreContext -- confirm).
    """

    db_engine: CustomEngine
    # None inside workers (reset in ExampleStorageWorker._run); push_tasks() sets a live storage.
    storage: Optional[ExampleStorage] = None


class ExampleBaseAction(BaseDBAction):
    """Common base for this example's DB-backed actions.

    ``context`` is one class-level ExampleContext instance; no subclass in this file
    overrides it, so all actions share it. ExampleStorageWorker._run() and push_tasks()
    populate it before actions are run.
    """

    context = ExampleContext()
    storage_context_cls = ExampleStorageContext


class ExampleAsyncAction(BaseAsyncDBAction, ExampleBaseAction):
    """Base for actions started via ``run_async()``.

    ``task_type`` is TaskType.RUN_ACTION, the task type ExampleActionWorker handles.
    """

    task_type = TaskType.RUN_ACTION


class ExampleStorageWorker(BaseStorageWorker):
    """Base worker that prepares the shared action context before each run.

    Every run copies this worker's request id and logger plus the application's DB
    engine into ExampleBaseAction.context and clears its storage, then defers to
    BaseStorageWorker._run().
    """

    storage_context_cls = ExampleStorageContext

    @copy_context
    async def _run(self):
        # copy_context presumably isolates these context writes per run -- confirm in sendr_utils.
        shared = ExampleBaseAction.context
        shared.request_id = self.request_id
        shared.logger = self.logger
        shared.db_engine = self.app.db_engine
        # Workers start without a bound storage.
        shared.storage = None
        outcome = await super()._run()
        return outcome


class ArbiterWorker(BaseStorageArbiterWorker):
    """Arbiter for the example application, backed by ExampleStorageContext."""

    storage_context_cls = ExampleStorageContext
    worker_heartbeat_period = 10  # NOTE(review): unit not visible here (presumably seconds) -- confirm


# App core end

# Actions begin
# Per-letter counters incremented by ReduceAction.handle(); defaultdict(int) starts each key at 0.
# NOTE(review): this is in-process state -- if workers run in separate processes each one keeps
# its own tallies; confirm against BaseStorageWorkerApplication.
amounts = defaultdict(int)
# Serialises ReduceAction's increment-and-print across coroutines.
# NOTE(review): on Python < 3.10 asyncio.Lock() binds to the loop returned by get_event_loop()
# at import time, so it must be used on that same loop -- confirm.
lock = asyncio.Lock()


class ReduceException(CoreFailError):
    """Raised by ReduceAction.handle() on a losing roll; listed in ExampleActionWorker.retry_exceptions."""

    pass


class ReduceAction(ExampleBaseAction):
    """Count one occurrence of ``letter`` in the module-level ``amounts`` tally.

    Each call first rolls a die and fails with ReduceException when the roll equals
    ``lucky_bullet`` (a 1-in-6 chance). ExampleActionWorker lists ReduceException in
    its ``retry_exceptions``.
    """

    # Drawn once, when the class is defined; every handle() call compares its own roll to it.
    lucky_bullet = randint(1, 6)

    def __init__(self, letter: str):
        self.letter = letter

    async def handle(self):
        """Increment the counter for ``self.letter`` and print all counters sorted by letter.

        Raises:
            ReduceException: when the random roll matches ``lucky_bullet``.
        """
        if randint(1, 6) == self.lucky_bullet:
            raise ReduceException('Russian roulette was successfully won')

        # ``amounts`` is only mutated, never rebound, so no ``global`` declaration is needed.
        async with lock:
            amounts[self.letter] += 1
            # Keys are unique, so sorting the (letter, count) pairs orders them by letter alone.
            stats = '\n'.join(f'{letter}={amount}' for letter, amount in sorted(amounts.items()))
            print(f'Stats: \n{stats}')


class ReduceAsyncAction(ExampleAsyncAction, ReduceAction):
    """ReduceAction started via ``run_async()``; its ``action_name`` is 'reduce'."""

    action_name = 'reduce'


class MapAction(ExampleBaseAction):
    """Fan a text out into one reduce step per character.

    The two dispatch routes alternate by character position:
    - even-indexed characters go through ``ReduceAsyncAction(letter=...).run_async()``;
    - odd-indexed characters are written with ``self.storage.task.create`` as
      TaskType.TASK_TYPE_REDUCE tasks carrying ``{'letter': ...}`` params.
    """

    def __init__(self, text: str):
        self.text = text

    async def handle(self):
        for index, letter in enumerate(self.text):
            if index % 2 == 0:
                await ReduceAsyncAction(letter=letter).run_async()
            else:
                await self.storage.task.create(task_type=TaskType.TASK_TYPE_REDUCE, params={'letter': letter})


class MapAsyncAction(ExampleAsyncAction, MapAction):
    """MapAction started via ``run_async()``; its ``action_name`` is 'map'."""

    action_name = 'map'


# Actions end

class ExampleActionWorker(ExampleStorageWorker, BaseActionStorageWorker):
    """Worker for TaskType.RUN_ACTION tasks, able to run MapAsyncAction and ReduceAsyncAction."""

    task_type = TaskType.RUN_ACTION
    worker_type = WorkerType.RUN_ACTION
    actions = (MapAsyncAction, ReduceAsyncAction)
    # Presumably tasks failing with these are retried rather than failed outright --
    # confirm with BaseActionStorageWorker.
    retry_exceptions = (ReduceException,)


class ExampleMapperWorker(ExampleStorageWorker):
    """Worker that handles TaskType.TASK_TYPE_MAP tasks with MapAction.

    Such tasks are created with ``{'text': ...}`` params in push_tasks(); presumably those
    params become MapAction's constructor arguments -- confirm with BaseStorageWorker.
    """

    worker_type = WorkerType.WORKER_TYPE_MAPPER
    task_action_mapping = {
        TaskType.TASK_TYPE_MAP: MapAction,
    }


class ExampleReducerWorker(ExampleStorageWorker):
    """Worker that handles TaskType.TASK_TYPE_REDUCE tasks with ReduceAction.

    Such tasks are created with ``{'letter': ...}`` params in MapAction.handle(); presumably
    those params become ReduceAction's constructor arguments -- confirm with BaseStorageWorker.
    """

    worker_type = WorkerType.WORKER_TYPE_REDUCER
    task_action_mapping = {
        TaskType.TASK_TYPE_REDUCE: ReduceAction,
    }


class ExampleWorkerApplication(BaseStorageWorkerApplication):
    """Worker application: ArbiterWorker plus the action, mapper and reducer workers."""

    debug = True
    arbiter_cls = ArbiterWorker
    # (worker class, N) pairs; N is presumably the number of instances to start -- confirm.
    workers = [
        (ExampleActionWorker, 1),
        (ExampleMapperWorker, 1),
        (ExampleReducerWorker, 1),
    ]


def make_db_engine(loop, **overrides):
    """Create the example's database engine on ``loop`` and return it.

    Connection settings default to a local instance (127.0.0.1:5442, database and user
    ``sendr_qtools``). Any keyword passed in ``overrides`` replaces or extends the matching
    setting, e.g. ``make_db_engine(loop, host='db.local', port='5432')``. Calling it with
    only ``loop`` behaves exactly as before.

    Blocks on ``loop`` until the engine has been created.
    """
    db_configuration = {
        'database': 'sendr_qtools',
        'user': 'sendr_qtools',
        'password': 'P@ssw0rd',
        'host': '127.0.0.1',
        'port': '5442',
        'sslmode': 'disable',
        'connect_timeout': 10,
        'timeout': 5,
        'target_session_attrs': 'read-write',
    }
    # Caller-supplied settings win over the local-development defaults above.
    db_configuration.update(overrides)
    return loop.run_until_complete(create_engine(loop=loop, **db_configuration))


def run_workers():
    """Build ExampleWorkerApplication on a fresh DB engine and start it.

    Calls ``start('::', 8080)``, which presumably binds all interfaces on port 8080 --
    confirm with BaseStorageWorkerApplication.
    """
    rule = '-' * 5
    print(f'{rule} Run workers {rule}')
    event_loop = asyncio.get_event_loop()
    engine = make_db_engine(event_loop)
    application = ExampleWorkerApplication(db_engine=engine)
    application.start('::', 8080)


def push_tasks():
    """Seed the task queue from ``text.txt`` located next to this script.

    Lines are numbered from 0, and their trailing newline is dropped:
    - even-numbered lines go through ``MapAsyncAction(text=...).run_async()``;
    - odd-numbered lines are inserted directly as TaskType.TASK_TYPE_MAP tasks with
      ``{'text': ...}`` params.

    The DB connection and engine are always released, even if seeding fails.
    """
    print(f"{'-' * 5} Push tasks {'-' * 5}")
    loop = asyncio.get_event_loop()
    root = os.path.dirname(os.path.abspath(__file__))

    db_engine = make_db_engine(loop)
    try:
        conn = loop.run_until_complete(db_engine.acquire())
        try:
            storage = ExampleStorage(conn, logger=default_logger)

            # Actions run from here read the shared context, so point it at this connection.
            ExampleBaseAction.context.request_id = 'push_tasks'
            ExampleBaseAction.context.logger = default_logger
            ExampleBaseAction.context.db_engine = db_engine
            ExampleBaseAction.context.storage = storage

            with open(os.path.join(root, 'text.txt')) as f:
                for i, line in enumerate(f):
                    # Drop the line terminator so it is not mapped and counted as a "letter".
                    text = line.rstrip('\n')
                    if i % 2 == 0:
                        loop.run_until_complete(MapAsyncAction(text=text).run_async())
                    else:
                        coro = storage.task.create(task_type=TaskType.TASK_TYPE_MAP, params={'text': text})
                        loop.run_until_complete(coro)
        finally:
            loop.run_until_complete(conn.close())
    finally:
        db_engine.close()
        loop.run_until_complete(db_engine.wait_closed())


if __name__ == '__main__':
    # Seed the task queue first, then start the worker application.
    for step in (push_tasks, run_workers):
        step()
