import asyncio
import datetime
import logging
import typing
from concurrent.futures import ThreadPoolExecutor

import yt.wrapper as yt
from sqlalchemy import update

from .mail_id import MailId
from .models import Engine, Report, Table, Signature, ChoicesHistory, Choice
from .reports_group import ReportsSignatureGroup

# Module logger; shares the "yt" channel name rather than the module path.
logger: logging.Logger = logging.getLogger(name="yt")


class ReportsProducer:
    """Import report tables from a YT directory into the database and group them.

    ``root`` must contain a ``reports`` directory, whose tables hold report
    records, and a ``mail_ids_flow`` directory with tables of the same names
    mapping ``sig_id`` to ``mail_id``.
    """

    def __init__(self, root: str, engine: Engine):
        """Create a producer for the YT directory ``root``.

        :param root: YT path; trailing slashes are stripped.
        :param engine: database wrapper exposing ``session()`` and
            ``get_choices_history()``.
        :raises ValueError: if ``root`` does not exist in YT.
        """
        root = root.rstrip('/')
        # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
        if not yt.exists(root):
            raise ValueError(f"YT root {root!r} does not exist")

        self._root = root
        self.engine = engine

        # One worker: blocking ``_fetch`` runs are executed one at a time.
        self.pool = ThreadPoolExecutor(max_workers=1)

    def __delitem__(self, sig: Signature):
        """Mark every report of ``sig`` as proceed so it is no longer loaded.

        Matches reports on ``shingle_type``, ``shingle`` and ``domain_hash`` and
        commits the update immediately.
        """
        # Each criterion gets its own ``.where()`` (joined with SQL AND). Python's
        # ``and`` would evaluate the clauses' truthiness and keep just one of them.
        stmt = (
            update(Report)
            .where(Report.shingle_type == sig.shingle_type)
            .where(Report.shingle == sig.shingle)
            .where(Report.domain_hash == sig.domain_hash)
            .values(proceed=True)
        )
        session = self.engine.session()
        try:
            session.execute(stmt)
            session.commit()
        except Exception:
            session.rollback()
            raise

    async def fetch(self, refresh_time):
        """Run ``_fetch`` in the worker thread forever.

        :param refresh_time: seconds to sleep between consecutive runs.
        """
        loop = asyncio.get_running_loop()
        while True:
            await loop.run_in_executor(self.pool, self._fetch)
            await asyncio.sleep(refresh_time)

    def _fetch(self):
        """Import every not-yet-proceed table under ``<root>/reports`` (blocking).

        Each table is committed on its own. A failure is logged and rolled back
        without stopping the remaining tables.
        """
        session = self.engine.session()
        try:
            logger.info("fetching")
            reports_dir = yt.ypath_join(self._root, "reports")

            proceed_reports_ids = {report.id for report in
                                   session.query(Report).filter(Report.proceed == 1)}

            proceed_tables = {table.table for table in
                              session.query(Table).filter(Table.proceed == 1)}
            reports_tables = set(yt.list(reports_dir)) - proceed_tables

            logger.debug(f"proceed_tables:{proceed_tables}")
            logger.debug(f"reports_tables:{reports_tables}")
            logger.debug(f"proceed_reports_ids:{proceed_reports_ids}")
        except Exception as e:
            logger.exception(e)
            session.rollback()
            return

        for table in reports_tables:
            try:
                table_model = self._import_table(session, reports_dir, table, proceed_reports_ids)
                session.commit()
                logger.debug(f"loaded {table_model}")
            except Exception as e:
                # Isolate failures: set order is arbitrary, so one persistently
                # broken table must not block every other table on each cycle.
                logger.exception(e)
                session.rollback()

    def _import_table(self, session, reports_dir: str, table: str, proceed_reports_ids: typing.Set):
        """Stage one table's reports in ``session`` and flag the table proceed (no commit)."""
        # Reuse an existing row (left with proceed == 0 by an earlier run).
        # A fresh ``Table`` that is never added would not persist the flag below.
        table_model = session.query(Table).get(table)
        if table_model is None:
            table_model = Table(table=table)
            session.add(table_model)

        mail_ids = self._read_mail_ids(table)

        path = yt.ypath_join(reports_dir, table)
        logger.info(f"fetching reports from {path}")
        for record in yt.read_table(path, enable_read_parallel=True):
            report = Report.parse(table=table, record=record, mail_id=mail_ids.get(record["sig_id"]))
            if report is None or report.id in proceed_reports_ids:
                continue
            session.add(report)

        table_model.proceed = True
        return table_model

    def _read_mail_ids(self, table: str) -> dict:
        """Map ``sig_id`` to parsed ``MailId`` from ``<root>/mail_ids_flow/<table>``."""
        path = yt.ypath_join(self._root, "mail_ids_flow", table)
        logger.info(f"fetching ids from {path}")
        try:
            return {rec["sig_id"]: MailId.parse(rec["mail_id"]) for rec in
                    yt.read_table(path, enable_read_parallel=True)}
        except Exception as e:
            # Best effort: reports are still imported, with mail_id=None.
            logger.error(str(e))
            return {}

    @staticmethod
    def check_stable_choice(choices_history: typing.List[ChoicesHistory]) -> typing.Optional[Choice]:
        """Return the settled verdict in ``choices_history``, or ``None``.

        Only entries still in force are counted. An entry expires ``duration``
        hours after its ``date``. ``Choice.BAN`` is returned when
        ``bans // 2 > falses`` (i.e. at least ``2 * falses + 2`` bans), and
        ``Choice.FALSE`` symmetrically. Any other split returns ``None``.
        """
        if not choices_history:
            return None

        # NOTE(review): naive local time; assumes ``entry.date`` is naive as well.
        now = datetime.datetime.now()
        bans = 0
        falses = 0
        for entry in choices_history:
            if entry.date + datetime.timedelta(hours=entry.duration) < now:
                continue  # expired

            if entry.choice == Choice.BAN:
                bans += 1
            elif entry.choice == Choice.FALSE:
                falses += 1

        if bans // 2 > falses:
            return Choice.BAN
        if falses // 2 > bans:
            return Choice.FALSE
        return None

    def load_false_choices(self) -> typing.List[ChoicesHistory]:
        """Return every ``ChoicesHistory`` row whose choice is ``Choice.FALSE``."""
        session = self.engine.session()

        return session.query(ChoicesHistory).filter(ChoicesHistory.choice == Choice.FALSE).all()

    def load_reports(self) -> typing.Dict[Signature, ReportsSignatureGroup]:
        """Group reports with ``proceed == 0`` by signature.

        The choices history is checked once per signature. If
        ``check_stable_choice`` yields a verdict, the signature's reports are
        marked proceed (``del self[signature]``) and excluded. Otherwise they
        are returned grouped under the signature.
        """
        session = self.engine.session()
        logger.info("start loading groups")

        # Materialize first so no result cursor is open while the updates below
        # write to the same table.
        reports = session.query(Report).filter(Report.proceed == 0).all()

        groups: typing.Dict[Signature, ReportsSignatureGroup] = {}
        settled = set()
        for report in reports:
            signature = report.signature
            if signature in settled:
                continue
            group = groups.get(signature)
            if group is None:
                choices_history = self.engine.get_choices_history(signature)
                if ReportsProducer.check_stable_choice(choices_history) is not None:
                    settled.add(signature)
                    continue
                group = groups[signature] = ReportsSignatureGroup(signature)
            group.add(report)

        for signature in settled:
            del self[signature]

        logger.info(f"load {len(groups)} groups")

        return groups
