import sys
import logging
import dateutil.parser

from datetime import datetime
from collections import namedtuple, defaultdict, Counter
from mail.pypg.pypg.common import transaction, qexec
from mail.pypg.pypg.query_conf import load_from_package
from mail.husky.husky.types import Task, ResultData, Errors
from ora2pg.sharpei import get_pg_dsn_from_sharpei
from ora2pg.tools import (
    LidTabMap,
    make_lid_tab_map,
    create_tabs,
    mark_user_as_can_read_tabs,
    make_tabs_list,
)
from pymdb.operations import MoveMessages
from pymdb.queries import Queries
from .base import BaseTask

# Module-level logger shared by every task instance in this module.
log = logging.getLogger(__name__)

# Named query set for this module, loaded via pypg's query_conf; the task
# below uses QUERIES.get_messages_newer and QUERIES.get_messages_older.
QUERIES = load_from_package(__package__, __file__)


class MoveMessagesToTabs(BaseTask):
    """Husky task that moves a user's inbox messages into tabs.

    Inbox messages are read in chunks of ``mids_limit`` and every chunk is
    split by tab via ``lid_tab_map`` (built from the user's labels and
    ``tabs_mapping``). The work is done in two passes:

    * the "newer" pass (``QUERIES.get_messages_newer``) moves every message
      it reads, without looking at tab limits;
    * the "older" pass (``QUERIES.get_messages_older``) only moves messages
      into tabs that are below their ``min_count`` limit, and stops as soon
      as every tab listed in ``tabs_mapping`` has reached its limit.

    When ``final`` is true (the default) the user is marked as able to read
    tabs at the end.

    Task arguments read here: ``mapping``, ``tabs_limits``, ``final``,
    ``chunk_size``, ``from_date``, ``force``, ``skip_done``, ``skip_full``.
    """

    MIDS_LIMIT = 500

    name = Task.MoveMessagesToTabs
    required_args = []

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.all_tabs = []
        self.inbox_fid = None
        self.lid_tab_map = LidTabMap(None)
        # Per-tab minimum message count the "older" pass tries to reach.
        self.tabs_limit = Counter()
        # Per-tab number of messages currently known to be in the tab.
        self.tabs_total = Counter()

    @property
    def tabs_mapping(self):
        """Tab rules: the ``mapping`` task argument, else ``config.tabs_mapping``.

        Each rule is a dict with a ``type`` key and an optional ``min_count``.
        """
        return self.task_args.get('mapping') or self.config.tabs_mapping

    @property
    def tabs_params(self):
        """Per-tab limit overrides from the ``tabs_limits`` task argument.

        Entries have the same ``type`` / ``min_count`` keys as the rules in
        ``tabs_mapping`` and are applied after them, so they win.
        """
        return self.task_args.get('tabs_limits', [])

    @property
    def final_run(self):
        """Whether to mark the user as ``can_read_tabs`` when done (default True)."""
        return self.task_args.get('final', True)

    @property
    def mids_limit(self):
        """Chunk size for message queries (``chunk_size`` or ``MIDS_LIMIT``)."""
        return self.task_args.get('chunk_size', self.MIDS_LIMIT)

    @property
    def from_date(self):
        """Parsed ``from_date`` task argument, or ``datetime.min`` when unset."""
        date = self.task_args.get('from_date')
        return dateutil.parser.parse(date) if date \
            else datetime.min

    @property
    def force(self):
        """If true, existing tab counts are ignored and ``read_all`` is passed to queries."""
        return self.task_args.get('force', False)

    @property
    def skip_done(self):
        """If true, users that already have ``can_read_tabs`` are skipped."""
        return self.task_args.get('skip_done', False)

    @property
    def skip_full(self):
        """If true, users whose tabs already hold all inbox messages are skipped."""
        return self.task_args.get('skip_full', False)

    def run(self):
        """Move the user's inbox messages into tabs.

        Returns a ``ResultData`` carrying ``Errors.NotSupported`` when the user
        is not ``is_here``. Otherwise returns ``None``, including when the user
        is skipped because of ``skip_done`` or ``skip_full``.
        """
        dsn = self._get_dsn_by_uid()
        with transaction(dsn) as conn:
            queries = Queries(conn, self.uid)
            user = queries.user()
            if not user.is_here:
                return ResultData(
                    transfer_id=self.transfer_id,
                    error=Errors.NotSupported,
                    error_message='Try process user which is not here',
                    task_output=None,
                )
            if self.skip_done and user.can_read_tabs:
                log.info('Will skip user, because skip_done')
                return
            if self.skip_full and self._user_is_fully_done(conn):
                log.info('Will skip user, because skip_full')
                return

        self._make_tabs_map_parameters(dsn)
        self._create_tabs(dsn)

        log.info('Messages in tabs at start: %s', self.tabs_total)
        # Each pass pages through the inbox starting from a sentinel
        # "maximal" cursor; a pass returns None once a chunk is short.
        last_message = self._max_message
        while last_message:
            last_message = self._run_for_newer(dsn, last_message)

        log.info('Messages in tabs after processing newest: %s', self.tabs_total)
        last_message = self._max_message
        while not self._tabs_are_full() and last_message:
            last_message = self._run_for_older(dsn, last_message)

        log.info('Messages in tabs after all: %s', self.tabs_total)
        if self.final_run:
            with transaction(dsn) as conn:
                mark_user_as_can_read_tabs(conn, self.uid)

    def _get_dsn_by_uid(self):
        """Resolve the user's mail database DSN through sharpei."""
        return get_pg_dsn_from_sharpei(
            sharpei=self.config.sharpei,
            uid=self.uid,
            dsn_suffix=self.config.maildb_dsn_suffix,
        )

    def _create_tabs(self, dsn):
        """Create all configured tabs for the user.

        Unless ``force`` is set, the per-tab totals are seeded with the message
        counts reported for the tabs. With ``force`` they keep starting at zero.
        """
        with transaction(dsn) as conn:
            tabs = create_tabs(conn, self.uid, self.all_tabs)
        if not self.force:
            # Keep Counter semantics (missing tab -> 0). A plain dict here made
            # _tab_is_full / _process_messages raise KeyError for any tab that
            # create_tabs did not report.
            self.tabs_total = Counter({t.tab: t.message_count for t in tabs})

    def _make_tabs_map_parameters(self, dsn):
        """Load inbox fid and labels, and build the tab list, limits and lid->tab map."""
        with transaction(dsn) as conn:
            queries = Queries(conn, self.uid)
            labels = queries.labels()
            self.inbox_fid = queries.folder_by_type('inbox').fid

        self.all_tabs = make_tabs_list(self.tabs_mapping)

        self.lid_tab_map = make_lid_tab_map(labels, self.tabs_mapping)
        log.info('Will process user with rules: %s', self.lid_tab_map)

        # Limits from the mapping first, then explicit ``tabs_limits``
        # overrides, so the overrides take precedence.
        for rules in (self.tabs_mapping, self.tabs_params):
            for t in rules:
                self.tabs_limit[t['type']] = int(t.get('min_count', 0))

    def _run_for_newer(self, dsn, last_message):
        """Process one "newer" chunk; return the next cursor or None when done."""
        with transaction(dsn) as conn:
            messages = self._get_messages(conn, QUERIES.get_messages_newer, last_message)
            self._process_messages(conn, messages, with_total_check=False)
        return None if len(messages) < self.mids_limit else messages[-1]

    def _run_for_older(self, dsn, last_message):
        """Process one "older" chunk honoring tab limits; return next cursor or None."""
        with transaction(dsn) as conn:
            messages = self._get_messages(conn, QUERIES.get_messages_older, last_message)
            self._process_messages(conn, messages, with_total_check=True)
        return None if len(messages) < self.mids_limit else messages[-1]

    # Row shape produced by the message queries (mid, lids, received_date);
    # also used as the pagination cursor.
    Message = namedtuple('Message', [
        'mid',
        'lids',
        'received_date',
    ])

    def _get_messages(self, conn, query, last_message):
        """Fetch the next chunk of inbox messages after ``last_message``.

        Returns an empty list when there is no cursor or no inbox fid.
        """
        if not last_message or not self.inbox_fid:
            return []
        cur = qexec(
            conn=conn,
            query=query,
            uid=self.uid,
            fid=self.inbox_fid,
            from_date=self.from_date,
            limit=self.mids_limit,
            last_date=last_message.received_date,
            last_mid=last_message.mid,
            read_all=self.force,
        )
        messages = [self.Message(*row) for row in cur]
        log.info('Found %d messages', len(messages))
        return messages

    def _process_messages(self, conn, messages, with_total_check):
        """Move ``messages`` into their tabs and update ``tabs_total``.

        With ``with_total_check`` a tab's whole batch is skipped once that tab
        has reached its limit. A batch is still moved in full even if it
        overshoots the limit.
        """
        to_move = self._split_messages(messages)

        for tab, mids in to_move.items():
            if with_total_check and self._tab_is_full(tab):
                continue
            log.info('Will move %d messages to tab "%s"', len(mids), tab)
            MoveMessages(conn, self.uid)(mids, self.inbox_fid, tab)
            self.tabs_total[tab] += len(mids)

    def _split_messages(self, messages):
        """Group message mids by the tab chosen for their labels."""
        res = defaultdict(list)
        for m in messages:
            res[self.lid_tab_map.get_tab(m.lids)].append(m.mid)
        return res

    def _tab_is_full(self, tab):
        """True when ``tab`` holds at least its configured minimum count."""
        return self.tabs_total[tab] >= self.tabs_limit[tab]

    def _tabs_are_full(self):
        """True when every tab listed in ``tabs_mapping`` is full."""
        return all(self._tab_is_full(t['type']) for t in self.tabs_mapping)

    @property
    def _max_message(self):
        """Sentinel cursor that sorts after every real message."""
        return self.Message(
            mid=sys.maxsize,
            lids=[],
            received_date=datetime.max,
        )

    def _user_is_fully_done(self, conn):
        """True if the user can read tabs and tabs hold at least the inbox count."""
        queries = Queries(conn, self.uid)
        user = queries.user()
        if not user.can_read_tabs:
            return False

        inbox_count = queries.folder_by_type('inbox').message_count
        tabs_count = sum(t.message_count for t in queries.tabs() if t.tab)
        return tabs_count >= inbox_count
