import aiopg
import psycopg2
from psycopg2.extras import RealDictCursor
from datetime import timedelta, datetime
from dataclasses import dataclass
from abc import ABCMeta, abstractmethod, abstractproperty
from collections import defaultdict

from .helpers import chunks
from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn
from .task import TaskParams
from mail.shiva.stages.api.props.logger import get_uid_logger

log = get_uid_logger(__name__)


@dataclass
class ArchivationRule:
    """Archivation settings of one folder, built from a ``mail.folder_archivation_rules`` row."""
    # Owner of the folder; used together with ``fid`` to filter every per-folder query.
    uid: str = None
    # Identifier of the folder the rule applies to.
    fid: str = None
    # Age limit in days: messages received earlier than ``now() - keep_days`` are selected.
    keep_days: int = None
    # Message-count limit: when the folder holds more messages than this, the oldest excess is selected.
    max_size: int = None


@dataclass
class Message:
    """Minimal view of a ``mail.box`` row: just what is needed to delete or move a message."""
    # Message identifier passed on to the delete/move stored procedures.
    mid: str = None
    # Receive timestamp; its ``year`` decides the archive subfolder.
    received_date: datetime = None


async def get_messages_by_date(conn, folder):
    """Select the messages of a folder that are older than its age limit.

    Args:
        conn: open database connection whose cursors yield mapping rows
            (each row is unpacked into ``Message`` by keyword).
        folder: rule object providing ``uid``, ``fid`` and ``keep_days``.

    Returns:
        A list of ``Message`` objects ordered by ``received_date`` ascending.
        The list is empty when the rule has no ``keep_days`` limit.
    """
    if folder.keep_days is None:
        # No age limit configured for this rule. ``timedelta(None)`` would
        # raise TypeError and abort the whole run, so report "nothing to do".
        return []

    async with conn.cursor() as cur:
        await cur.execute(
            '''
            SELECT mid, received_date
              FROM mail.box
             WHERE uid = %(uid)s
               AND fid = %(fid)s
               AND received_date < now() - %(days)s
             ORDER BY received_date
            ''',
            dict(
                uid=folder.uid,
                fid=folder.fid,
                # Sent as a timedelta so the driver adapts it to an SQL interval.
                days=timedelta(days=folder.keep_days)
            )
        )
        return [Message(**r) async for r in cur]


async def get_folder_messages_count(conn, folder):
    """Read the stored message counter of a folder.

    Args:
        conn: open database connection whose cursors yield mapping rows.
        folder: object providing the ``uid`` and ``fid`` of the folder.

    Returns:
        The ``message_count`` value from ``mail.folders``, or ``0`` when no
        row matches the given ``uid``/``fid``.
    """
    async with conn.cursor() as cur:
        await cur.execute(
            '''
            SELECT message_count
              FROM mail.folders
             WHERE uid = %(uid)s
               AND fid = %(fid)s
            ''',
            dict(
                uid=folder.uid,
                fid=folder.fid
            )
        )
        row = await cur.fetchone()
        # fetchone() gives None when the query matched nothing; subscripting
        # it would raise TypeError, so a missing folder counts as empty.
        if row is None:
            return 0
        return row['message_count']


async def get_messages_by_count(conn, folder):
    """Select the oldest messages that exceed a folder's size limit.

    Args:
        conn: open database connection whose cursors yield mapping rows
            (each row is unpacked into ``Message`` by keyword).
        folder: rule object providing ``uid``, ``fid`` and ``max_size``.

    Returns:
        A list with the ``message_count - max_size`` oldest ``Message``
        objects of the folder, ordered by ``received_date`` ascending.
        The list is empty when the folder is within its limit or the rule
        has no ``max_size`` limit.
    """
    if folder.max_size is None:
        # No size limit configured for this rule. Comparing an int with None
        # would raise TypeError, and there is nothing to trim anyway.
        return []

    message_count = await get_folder_messages_count(conn, folder)
    excess = message_count - folder.max_size
    if excess <= 0:
        return []

    async with conn.cursor() as cur:
        await cur.execute(
            '''
                SELECT mid, received_date
                  FROM mail.box
                 WHERE uid = %(uid)s
                   AND fid = %(fid)s
                 ORDER BY received_date LIMIT %(count)s
                ''',
            dict(
                uid=folder.uid,
                fid=folder.fid,
                count=excess
            )
        )
        return [Message(**r) async for r in cur]


class RuleApplicator(metaclass=ABCMeta):
    """Base class that applies one type of folder archivation rule.

    Subclasses provide ``rule_type`` (the ``archive_type`` value to load rules
    for) and ``process_messages`` (what to do with the selected messages).

    The metaclass is declared with the Python 3 ``metaclass=`` keyword; a
    ``__metaclass__`` class attribute is ignored by Python 3, which would
    leave the abstract members below unenforced. As a result this base class
    itself cannot be instantiated.
    """

    def __init__(self, conn, get_messages=get_messages_by_date):
        """Store the connection and the message selection strategy.

        Args:
            conn: open database connection whose cursors yield mapping rows.
            get_messages: coroutine function ``(conn, folder) -> list`` that
                picks the messages a rule applies to.
        """
        self.conn = conn
        self._get_messages = get_messages

    @property
    @abstractmethod
    def rule_type(self):
        """Value of ``archive_type`` identifying the rules this applicator handles."""

    async def get_folders(self):
        """Load every rule of this applicator's ``rule_type``.

        Returns:
            A list of ``ArchivationRule`` objects, one per matching row.
        """
        async with self.conn.cursor() as cur:
            await cur.execute(
                '''
                SELECT uid, fid, keep_days, max_size
                  FROM mail.folder_archivation_rules
                 WHERE archive_type = %(rule_type)s::mail.folder_archivation_type
                ''',
                dict(rule_type=self.rule_type)
            )
            return [ArchivationRule(**r) async for r in cur]

    async def get_messages(self, folder):
        """Select the messages of ``folder`` using the configured strategy."""
        return await self._get_messages(self.conn, folder)

    @abstractmethod
    async def process_messages(self, folder, messages):
        """Apply the rule to one chunk of ``messages`` belonging to ``folder``."""

    async def __call__(self, chunk_size):
        """Apply the rule to every configured folder.

        Args:
            chunk_size: chunk size forwarded to the ``chunks`` helper; each
                chunk is handed to one ``process_messages`` call.
        """
        for folder in await self.get_folders():
            all_messages = await self.get_messages(folder)
            for chunk in chunks(all_messages, chunk_size):
                # Materialize the chunk so subclasses can call len() on it
                # and iterate it more than once.
                await self.process_messages(folder, list(chunk))


class CleanApplicator(RuleApplicator):
    """Applicator for ``clean`` rules: the selected messages are deleted."""

    rule_type = 'clean'

    @staticmethod
    async def delete_mids(conn, uid, mids):
        """Call ``code.delete_messages`` for the given user and message ids."""
        query_args = {'uid': uid, 'mids': mids}
        async with conn.cursor() as cursor:
            await cursor.execute(
                'SELECT * FROM code.delete_messages(%(uid)s, %(mids)s)',
                query_args,
            )

    async def process_messages(self, folder, messages):
        """Delete one chunk of messages and log the outcome.

        A database error is logged and swallowed, so the caller carries on
        with the next chunk.
        """
        owner = folder.uid
        doomed_mids = [message.mid for message in messages]
        try:
            await self.delete_mids(self.conn, owner, doomed_mids)
        except psycopg2.DatabaseError:
            log.exception('exception while deleting messages', uid=owner)
            return
        log.info(f'successfully deleted {len(messages)} messages', uid=owner)


class ArchiveApplicator(RuleApplicator):
    """Applicator for ``archive`` rules: messages are moved into per-year subfolders."""

    rule_type = 'archive'

    @staticmethod
    def split_mids_by_year(messages):
        """Group message ids by the year of ``received_date``.

        Returns:
            A ``defaultdict`` mapping the year as a string to the list of
            mids received in that year, in input order.
        """
        grouped = defaultdict(list)
        for message in messages:
            year_key = str(message.received_date.year)
            grouped[year_key].append(message.mid)
        return grouped

    @staticmethod
    async def shared_folders(conn, uid):
        """Return the fids listed in ``mail.shared_folders`` for the user."""
        async with conn.cursor() as cursor:
            await cursor.execute(
                '''
                SELECT fid
                  FROM mail.shared_folders
                 WHERE uid = %(uid)s
                ''',
                {'uid': uid}
            )
            return [row['fid'] async for row in cursor]

    @staticmethod
    async def create_folder(conn, uid, name, parent_fid):
        """Call ``code.get_or_create_folder`` and return the resulting fid."""
        query_args = {'uid': uid, 'name': name, 'parent_fid': parent_fid}
        async with conn.cursor() as cursor:
            await cursor.execute(
                '''SELECT fid FROM code.get_or_create_folder(%(uid)s, %(name)s, %(parent_fid)s, 'user')''',
                query_args
            )
            row = await cursor.fetchone()
            return row['fid']

    @staticmethod
    async def add_fid_to_shared(conn, uid, fid):
        """Call ``code.add_folder_to_shared_folders`` for the given folder."""
        query_args = {'uid': uid, 'fid': fid}
        async with conn.cursor() as cursor:
            await cursor.execute(
                '''SELECT * FROM code.add_folder_to_shared_folders(%(uid)s, %(fid)s)''',
                query_args
            )

    @staticmethod
    async def move_mids(conn, uid, mids, dest_fid):
        """Call ``code.move_messages`` to move the mids into ``dest_fid``."""
        query_args = {'uid': uid, 'mids': mids, 'dest_fid': dest_fid}
        async with conn.cursor() as cursor:
            await cursor.execute(
                '''SELECT * FROM code.move_messages(%(uid)s, %(mids)s, %(dest_fid)s)''',
                query_args
            )

    async def process_messages(self, folder, messages):
        """Move one chunk of messages into year-named subfolders of ``folder``.

        All year folders are prepared first, then the messages are moved. A
        database error at any step is logged and swallowed, so the caller
        carries on with the next chunk.
        """
        grouped = self.split_mids_by_year(messages)
        owner = folder.uid
        year_fids = {}
        try:
            # First pass: make sure every destination folder exists.
            for year in grouped:
                year_fid = await self.create_folder(self.conn, owner, year, folder.fid)
                shared = await self.shared_folders(self.conn, owner)
                # A subfolder of a shared folder is registered as shared too,
                # unless it is already listed.
                parent_is_shared = folder.fid in shared
                if parent_is_shared and year_fid not in shared:
                    await self.add_fid_to_shared(self.conn, owner, year_fid)
                year_fids[year] = year_fid
            # Second pass: move each year's messages to its folder.
            for year, year_mids in grouped.items():
                await self.move_mids(self.conn, owner, year_mids, year_fids[year])
        except psycopg2.DatabaseError:
            log.exception('exception while moving messages', uid=owner)
            return
        log.info(f'successfully moved {len(messages)} messages', uid=owner)


@dataclass
class FolderArchivationParams(TaskParams):
    """Parameters of the folder archivation task; extends the common ``TaskParams``."""
    # Task identifier; how it is consumed is defined by TaskParams (not visible here).
    task_name: str = 'folder_archivation'
    # Chunk size passed to each applicator call, which forwards it to the ``chunks`` helper.
    chunk_size: int = 1000


async def shard_folder_archivation(params: FolderArchivationParams, stats):
    """Run all archivation passes against one shard.

    Resolves the shard DSN, opens a single connection with dict-style rows
    and runs four passes in a fixed order: clean by count, clean by date,
    archive by date, archive by count.
    """
    dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    # (applicator class, message selector) pairs, in execution order.
    passes = (
        (CleanApplicator, get_messages_by_count),
        (CleanApplicator, get_messages_by_date),
        (ArchiveApplicator, get_messages_by_date),
        (ArchiveApplicator, get_messages_by_count),
    )
    async with aiopg.connect(dsn, cursor_factory=RealDictCursor) as conn:
        for applicator_cls, select_messages in passes:
            await applicator_cls(conn, select_messages)(params.chunk_size)
