import logging
import asyncio
import tenacity
import ssl
import aiopg
import psycopg2
from psycopg2.extras import RealDictCursor
from aiohttp import ClientTimeout
from dataclasses import dataclass
from datetime import timedelta, datetime

from mail.python.theatre.detail.tvm import TvmServiceTickets

from .helpers import chunks
from .export_helper import randomize_start_time
from mail.shiva.stages.api.props.services.sharpei import get_shard_dsn
from .task import TaskParams
from mail.python.theatre.profiling.http import ProfiledClientSession
from mail.shiva.stages.api.settings.log import http_logger

log = logging.getLogger(__name__)

# Values compared against components of an st_id split on '.':
# part [0] is the unit id, part [1] holds a suid or a 'mail:<uid>'-style
# value, part [2] is the unique id.

# Presumably the account sending welcome ("hello") mails; its blobs are
# never erased from storage (see is_hello_user / greeting).
HELLO_SUID = '66466005'
HELLO_UID = 'mail:30391'
# Blobs in this unit are treated like greetings and never erased.
IMMUTABLE_UNIT = '102'
# Owner values of shared deliveries; such blobs are never erased.
SHARED_SUID = '0'
SHARED_UID = 'mail:0'
# The only unit id whose blobs can_delete() accepts without an 'E...:' unique id.
MDS_ONLY_UNIT_ID = '320'

# Cursor timeout used by filter_imap_copied(), in seconds (10 minutes).
FILTER_IMAP_TIMEOUT = 10 * 60


@dataclass
class DeletedMail:
    """One row of mail.storage_delete_queue, as built by get_queue()."""
    # NOTE(review): every field defaults to None although it is annotated
    # with a non-optional type; Optional[...] would describe it accurately.

    # qid column; queue rows are deleted/updated by this id.
    qid: int = None
    # deleted_date column; compared against now() - purge_ttl in get_queue().
    deleted_date: datetime = None
    # Storage id of the blob; split on '.' by can_delete().
    st_id: str = None
    # Failed storage deletions so far (get_queue() reads NULL as 0).
    fails_count: int = None


async def get_queue(conn, purge_ttl, chunk_size, jobs_count, job_no):
    """Yield pages of mail.storage_delete_queue rows that are due for purging.

    A row is selected when:

    * its deleted_date is older than ``now() - purge_ttl``;
    * ``qid % jobs_count == job_no``, so parallel jobs split the queue;
    * no other queue row with the same uid and st_id has a deleted_date
      inside the TTL window.

    Pages are read with keyset pagination on ``(deleted_date, qid)``. The
    ORDER BY must use the same key as the row comparison. Otherwise rows
    that tie on deleted_date could be skipped or fetched twice.

    NOTE(review): rows re-dated by update_storage_delete_queue_fails_count()
    (deleted_date + 1 day) may show up again in a later page of the same run
    if they are still older than the TTL window.

    :param conn: aiopg connection whose cursors yield mapping rows
        (DeletedMail(**row) is used).
    :param purge_ttl: interval a row must age before it is purged.
    :param chunk_size: maximum number of rows per yielded page.
    :param jobs_count: number of parallel jobs sharing the queue.
    :param job_no: this job's slot, presumably in ``range(jobs_count)``.
    :yields: non-empty lists of DeletedMail.
    """
    last_qid = 0
    last_deleted_date = datetime.fromtimestamp(0)
    while True:
        async with conn.cursor() as cur:
            await cur.execute(
                '''
                    SELECT qid, deleted_date, st_id, COALESCE(fails_count, 0) as fails_count
                      FROM mail.storage_delete_queue q1
                     WHERE q1.deleted_date < now() - %(delay)s
                       AND (deleted_date, qid) > (%(last_deleted_date)s, %(last_qid)s)
                       AND q1.qid %% %(jobs_count)s = %(job_no)s
                       AND NOT EXISTS (
                       SELECT 1
                         FROM mail.storage_delete_queue q2
                        WHERE q2.uid = q1.uid
                          AND q2.st_id = q1.st_id
                          AND q2.deleted_date >= now() - %(delay)s
                          AND q2.qid != q1.qid
                        )
                     ORDER BY deleted_date, qid
                     LIMIT %(chunk_size)s
                ''',
                dict(
                    delay=purge_ttl,
                    chunk_size=chunk_size,
                    last_qid=last_qid,
                    last_deleted_date=last_deleted_date,
                    jobs_count=jobs_count,
                    job_no=job_no,
                )
            )
            chunk = [DeletedMail(**r) async for r in cur]
        # The cursor is closed before yielding. While this generator is
        # suspended, the consumer runs its own queries on the same connection.

        # An empty page always ends the scan. This also guards against an
        # endless loop when chunk_size <= 0.
        if not chunk:
            return

        last_qid = chunk[-1].qid
        last_deleted_date = chunk[-1].deleted_date
        yield chunk

        # A short page means there is nothing left past the keyset position.
        if len(chunk) < chunk_size:
            return


async def get_welcome(conn):
    """Return the st_ids produced by ``code.welcome_mails()`` as a list.

    :param conn: aiopg connection whose cursors yield mapping rows.
    """
    async with conn.cursor() as cur:
        await cur.execute('SELECT st_id FROM code.welcome_mails()')
        welcome_stids = []
        async for row in cur:
            welcome_stids.append(row['st_id'])
        return welcome_stids


async def filter_imap_copied(conn, deleted_mails):
    """Return the mails whose st_id no longer appears in mail.messages.

    A queued blob whose st_id is still referenced by a row of mail.messages
    (presumably a copy of the message, e.g. made over IMAP) must not be
    erased. Such mails are filtered out. Input order is preserved.

    :param conn: aiopg connection whose cursors yield mapping rows.
    :param deleted_mails: objects with an ``st_id`` attribute.
    :returns: the subset of ``deleted_mails`` without references.
    """
    # Nothing to check: skip the round trip (the query has a long timeout).
    if not deleted_mails:
        return []

    async with conn.cursor(timeout=FILTER_IMAP_TIMEOUT) as cur:
        await cur.execute(
            '''
            SELECT st_id
              FROM mail.messages
             WHERE hashtext(st_id) IN (
                SELECT hashtext(ist)
                  FROM unnest(%(stids)s::text[]) AS ist)
               AND st_id = ANY(%(stids)s::text[])
            ''',
            dict(stids=[mail.st_id for mail in deleted_mails])
        )
        stids_copies = {r['st_id'] async for r in cur}
    return [mail for mail in deleted_mails if mail.st_id not in stids_copies]


def is_hello_user(stid_parts):
    """Tell whether the st_id's second component is the hello account.

    Both the suid form (HELLO_SUID) and the 'mail:<uid>' form (HELLO_UID)
    are recognised.
    """
    owner = stid_parts[1]
    return owner in (HELLO_SUID, HELLO_UID)


def is_immutable_unit(stid_parts):
    """Tell whether the st_id's unit id (first component) is IMMUTABLE_UNIT."""
    unit_id = stid_parts[0]
    return unit_id == IMMUTABLE_UNIT


def greeting(stid_parts):
    """Tell whether the blob is a greeting: owned by the hello account or in the immutable unit."""
    if is_hello_user(stid_parts):
        return True
    return is_immutable_unit(stid_parts)


def shared_delivery(stid_parts):
    """Tell whether the st_id's second component is the shared-delivery owner (suid or uid form)."""
    owner = stid_parts[1]
    return owner in (SHARED_SUID, SHARED_UID)


def is_mulca_only(stid_parts):
    """Tell whether the blob must be kept because it is not MDS-addressable.

    Returns True when the unit id is not MDS_ONLY_UNIT_ID and the unique id
    does not have the ``E...:...`` form. can_delete() refuses such blobs.

    :param stid_parts: st_id split on '.', with at least three parts.
    """
    unit_id = stid_parts[0]
    unique_id = stid_parts[2]
    # startswith() instead of unique_id[0]: an empty unique id (st_id such
    # as '1.2.') used to raise IndexError, which propagated out of
    # can_delete() and purge_storage(). Such a blob is now simply kept.
    return unit_id != MDS_ONLY_UNIT_ID \
        and not (unique_id.startswith('E') and ':' in unique_id)


def can_delete(st_id):
    """Tell whether the blob behind ``st_id`` may be erased from storage.

    An st_id with fewer than three '.'-separated parts is logged at debug
    level and rejected. Greetings, shared deliveries and mulca-only blobs
    are never erased.
    """
    parts = st_id.split('.')
    if len(parts) < 3:
        log.debug("invalid stid %s", st_id)
        return False

    protected = greeting(parts) \
        or shared_delivery(parts) \
        or is_mulca_only(parts)
    return not protected


class StorageEraser:
    """Erase message blobs through mulcagate's ``/gate/del/<st_id>`` handler.

    ``erase()`` issues one GET request per mail concurrently. It returns the
    mails answered with a 2xx status and the mails that failed (any other
    status, or an exception). Mails answered with HTTP 429 appear in neither
    list. They set a flag that makes the next non-empty ``erase()`` call
    first wait via ``randomize_start_time(max_delay=max_sleep)``.

    NOTE(review): per-call results live in instance attributes, so one
    instance must not run several ``erase()`` calls concurrently.
    """

    def __init__(self, mgate_host, mgate_port, cafile, http_session, mds_id, tvm, max_sleep):
        self.mulcagate = f'{mgate_host}:{mgate_port}'
        self.params = {'service': 'maildb'}
        self.http_session = http_session
        self.tvm = tvm
        self.mds_id = mds_id
        self.ssl = ssl.create_default_context(cafile=cafile)
        self.max_sleep = max_sleep
        # Set after an HTTP 429 so that the next erase() backs off first.
        self.need_sleep = False
        # Outcome of the latest erase() call. Initialised here so the
        # attributes exist even if erase_mail() runs before erase().
        self.storage_deleted = []
        self.storage_errors = []

    def build_url(self, st_id):
        """Return the mulcagate delete URL for ``st_id``."""
        return f'{self.mulcagate}/gate/del/{st_id}'

    @tenacity.retry(reraise=True, wait=tenacity.wait_fixed(1), stop=tenacity.stop_after_attempt(5))
    async def get_tvm(self):
        """Fetch a TVM service ticket for ``mds_id``.

        Up to 5 attempts are made, 1 second apart. The last error is re-raised.
        """
        return await self.tvm.get(self.mds_id)

    async def erase_mail(self, mail, headers):
        """Delete one blob and record the outcome on the instance.

        2xx responses go to ``storage_deleted``. HTTP 429 only sets
        ``need_sleep``, leaving the mail in neither list. Any other status,
        and any Exception raised by the request, go to ``storage_errors``.
        """
        try:
            async with self.http_session.get(
                    url=self.build_url(mail.st_id),
                    params=self.params,
                    headers=headers,
                    ssl=self.ssl
            ) as response:
                if 200 <= response.status < 300:
                    self.storage_deleted.append(mail)
                else:
                    reason = (await response.text()).strip()
                    log.error('non-200 storage code: %s, reason: %s, st_id: %s',
                              response.status, reason, mail.st_id)
                    if response.status == 429:
                        # Throttled: not counted as a failure; back off before the next batch.
                        self.need_sleep = True
                    else:
                        self.storage_errors.append(mail)
        except Exception:
            # Include the st_id so the failing blob can be identified from the log.
            log.exception('storage request failed, st_id: %s', mail.st_id)
            self.storage_errors.append(mail)

    async def erase(self, mails):
        """Erase ``mails`` concurrently.

        :returns: ``(storage_deleted, storage_errors)`` lists for this call.

        With no mails, two empty lists are returned immediately. No TVM
        ticket is requested, and a pending 429 back-off is kept for the next
        call that actually contacts storage.
        """
        self.storage_deleted = []
        self.storage_errors = []
        if not mails:
            return self.storage_deleted, self.storage_errors

        if self.need_sleep:
            await randomize_start_time(max_delay=self.max_sleep)
            self.need_sleep = False

        headers = {}
        if self.tvm and self.mds_id:
            headers['X-Ya-Service-Ticket'] = await self.get_tvm()
        await asyncio.gather(*(self.erase_mail(mail, headers) for mail in mails))
        return self.storage_deleted, self.storage_errors


def split_rows_by_deleteable(del_rows, welcome):
    """Partition queue rows into ``(deleteable_rows, ignore_rows)``.

    A row is deleteable when can_delete() accepts its st_id and the st_id is
    not one of the welcome-mail st_ids. Every other row is ignored. Input
    order is preserved within each list.

    :param del_rows: objects with an ``st_id`` attribute.
    :param welcome: collection of welcome-mail st_ids (a list from
        get_welcome(), or already a set).
    """
    # Membership is tested once per row; a set makes that O(1) instead of a
    # scan of the whole welcome list.
    welcome_stids = welcome if isinstance(welcome, (set, frozenset)) else set(welcome)
    deleteable_rows = []
    ignore_rows = []
    for r in del_rows:
        if can_delete(r.st_id) and r.st_id not in welcome_stids:
            deleteable_rows.append(r)
        else:
            ignore_rows.append(r)
    return deleteable_rows, ignore_rows


async def split_rows_by_backup(conn, del_rows):
    """Partition rows by whether their st_id is present in backup.box.

    :param conn: aiopg connection whose cursors yield mapping rows.
    :param del_rows: objects with an ``st_id`` attribute.
    :returns: ``(for_erase, backuped)``. ``backuped`` holds the rows whose
        st_id appears in backup.box, and ``for_erase`` holds the rest. Input
        order is preserved within each list.
    """
    # Nothing to look up: skip the database round trip.
    if not del_rows:
        return [], []

    async with conn.cursor() as cur:
        await cur.execute(
            '''
            SELECT st_id
              FROM backup.box
             WHERE hashtext(st_id) IN (
                SELECT hashtext(ist)
                  FROM unnest(%(stids)s::text[]) AS ist)
               AND st_id = ANY(%(stids)s::text[])
            ''',
            dict(stids=[mail.st_id for mail in del_rows])
        )
        backuped_stids = {r['st_id'] async for r in cur}

    for_erase = []
    backuped = []
    for mail in del_rows:
        (backuped if mail.st_id in backuped_stids else for_erase).append(mail)
    return for_erase, backuped


async def delete_from_storage_delete_queue(conn, storage_deleted):
    """Remove the given rows from mail.storage_delete_queue, matched by qid.

    :param conn: aiopg connection.
    :param storage_deleted: objects with a ``qid`` attribute.
    """
    async with conn.cursor() as cur:
        qids = [row.qid for row in storage_deleted]
        await cur.execute(
            '''
            DELETE FROM mail.storage_delete_queue
             WHERE qid = ANY(%(qids)s::bigint[])
            ''',
            {'qids': qids},
        )


async def update_storage_delete_queue_fails_count(conn, storage_errors):
    """Record a failed storage deletion for the given rows.

    Each row's fails_count is incremented (NULL counts as 0), and its
    deleted_date is pushed one day forward.

    :param conn: aiopg connection.
    :param storage_errors: objects with a ``qid`` attribute.
    """
    async with conn.cursor() as cur:
        qids = [row.qid for row in storage_errors]
        await cur.execute(
            '''
            UPDATE mail.storage_delete_queue
               SET fails_count = COALESCE(fails_count, 0) + 1,
                   deleted_date = deleted_date + '1 day'::interval
             WHERE qid = ANY(%(qids)s::bigint[])
            ''',
            {'qids': qids},
        )


async def purge_storage(conn, welcome, del_rows, storage_eraser, max_fails_count):
    """Process one batch of queue rows and update mail.storage_delete_queue.

    Steps:

    1. Rows rejected by split_rows_by_deleteable() are ignored.
    2. Rows whose st_id is still in mail.messages are left in the queue.
    3. Rows whose st_id is in backup.box are not sent to storage.
    4. The remaining rows are erased through ``storage_eraser``.

    Afterwards:

    * erased, ignored and backed-up rows are deleted from the queue;
    * failed rows with ``fails_count < max_fails_count`` are re-queued via
      update_storage_delete_queue_fails_count(). Other failed rows are
      dropped from the queue.
    * rows reported in neither erase list (e.g. HTTP 429) stay untouched.

    psycopg2 DataError/IntegrityError/InternalError are logged and
    swallowed. Any other exception propagates.
    """
    deletable_rows, ignore_rows = split_rows_by_deleteable(del_rows, welcome)
    try:
        deletable_rows_without_copies = await filter_imap_copied(conn, deletable_rows)
        rows_for_erase, backuped_rows = await split_rows_by_backup(conn, deletable_rows_without_copies)
        storage_deleted, storage_errors = await storage_eraser.erase(rows_for_erase)

        rows_to_update = []  # failed, will be retried later
        rows_to_drop = []    # failed too many times, given up on
        for row in storage_errors:
            if row.fails_count < max_fails_count:
                rows_to_update.append(row)
            else:
                rows_to_drop.append(row)

        rows_to_delete = rows_to_drop + storage_deleted + ignore_rows + backuped_rows
        if rows_to_delete:
            await delete_from_storage_delete_queue(conn, rows_to_delete)
        if rows_to_update:
            await update_storage_delete_queue_fails_count(conn, rows_to_update)
    except (psycopg2.DataError,
            psycopg2.IntegrityError,
            psycopg2.InternalError):
        log.exception("pg error while storage-deleting")
    else:
        log.info(
            "successfully purged %d messages, "
            "ignored %d immutable messages, "
            "ignored %d has-copies messages, "
            "ignored %d backuped messages, "
            "updated %d failed messages, "
            "droped %d failed messages",
            len(storage_deleted),
            len(ignore_rows),
            len(deletable_rows) - len(deletable_rows_without_copies),
            len(backuped_rows),
            len(rows_to_update),
            # Only failed rows that were given up on. rows_to_delete also
            # contains the purged, ignored and backed-up rows.
            len(rows_to_drop),
        )


async def purge_storage_chunk(conn, welcome, queue_chunk, storage_eraser, delete_chunk_size, max_fails_count):
    """Purge one page of queue rows in sequential batches.

    ``queue_chunk`` is cut by chunks() into pieces of ``delete_chunk_size``
    rows (presumably consecutive slices). Each piece is handed to
    purge_storage() and awaited before the next one starts.
    """
    batches = chunks(queue_chunk, delete_chunk_size)
    for batch in batches:
        await purge_storage(conn=conn, welcome=welcome, del_rows=batch,
                            storage_eraser=storage_eraser, max_fails_count=max_fails_count)


@dataclass
class PurgeStorageParams(TaskParams):
    """Settings for the purge_storage task, consumed by shard_purge_storage().

    shard_purge_storage() also reads sharpei, db_user, shard_id, jobs_count
    and job_no, which are presumably declared on TaskParams.
    """
    task_name: str = 'purge_storage'
    # CA bundle for the mulcagate TLS context
    # (ssl.create_default_context(cafile=...) in StorageEraser).
    ca_cert_path: str = None
    # mulcagate address; delete URLs are '<host>:<port>/gate/del/<st_id>'.
    mgate_host: str = None
    mgate_port: int = None
    # aiohttp ClientTimeout(total=...) for the HTTP session, in seconds.
    mgate_total_timeout: int = 5
    # TVM client; a service ticket header is sent only when both tvm and
    # mds_id are set.
    tvm: TvmServiceTickets = None
    # TVM destination id the storage service ticket is requested for.
    mds_id: int = None
    # Queue rows are purged only once deleted_date < now() - purge_ttl.
    purge_ttl: timedelta = timedelta(days=7)
    # Page size when reading mail.storage_delete_queue (get_queue chunk_size).
    stids_in_memory: int = 5000
    # Rows handled per purge_storage() call within one page.
    delete_chunk_size: int = 100
    # A failed row whose fails_count has reached this value is dropped from
    # the queue instead of being retried.
    max_fails_count: int = 10
    # Upper bound passed to randomize_start_time() before a shard is
    # processed (presumably seconds).
    max_delay: int = 20 * 60
    # Upper bound of the back-off before the next erase after storage
    # answered HTTP 429 (presumably seconds).
    max_sleep: int = 5 * 60


async def shard_purge_storage(params: PurgeStorageParams, stats):
    """Purge the storage delete queue of one shard for this job.

    The function first waits via randomize_start_time(max_delay=...). It
    then resolves the shard DSN through sharpei and opens one database
    connection and one HTTP session. Finally it walks the due queue pages
    and purges them batch by batch.

    :param params: task settings, see PurgeStorageParams.
    :param stats: metrics object handed to get_shard_dsn() and to the
        profiled HTTP session.
    """
    await randomize_start_time(max_delay=params.max_delay)

    dsn = await get_shard_dsn(params.sharpei, params.db_user, params.shard_id, stats)
    async with aiopg.connect(dsn, cursor_factory=RealDictCursor) as conn:
        async with ProfiledClientSession(
            metrics=stats,
            logger=http_logger.get_logger(),
            timeout=ClientTimeout(total=params.mgate_total_timeout),
        ) as http_session:
            # Loaded once per shard. A set keeps the per-row "is it a welcome
            # mail" membership test O(1) for every batch.
            welcome = set(await get_welcome(conn))
            storage_eraser = StorageEraser(
                mgate_host=params.mgate_host,
                mgate_port=params.mgate_port,
                http_session=http_session,
                cafile=params.ca_cert_path,
                mds_id=params.mds_id,
                tvm=params.tvm,
                max_sleep=params.max_sleep,
            )
            async for queue_chunk in get_queue(
                conn=conn,
                purge_ttl=params.purge_ttl,
                chunk_size=params.stids_in_memory,
                jobs_count=params.jobs_count,
                job_no=params.job_no,
            ):
                await purge_storage_chunk(
                    conn=conn,
                    welcome=welcome,
                    queue_chunk=queue_chunk,
                    storage_eraser=storage_eraser,
                    delete_chunk_size=params.delete_chunk_size,
                    max_fails_count=params.max_fails_count,
                )
