import logging
from contextlib import closing

import click
import psycopg2

from mail.tools.dbaas.helpers.types.env import Envs, Env
from mail.tools.dbaas.helpers.yav import get_users_from_yav

log = logging.getLogger(__name__)


def get_shards(conn):
    """Return ``(shard_id, name)`` pairs for every shard in ``shards.shards``.

    Shards are ordered by ``shard_id`` so repeated runs process (and print)
    them in the same order.

    :param conn: open DB-API connection to the sharpei database.
    :returns: list of ``(shard_id, name)`` tuples.
    """
    # The cursor is closed deterministically instead of being left to the GC.
    with conn.cursor() as cur:
        cur.execute('SELECT shard_id, name FROM shards.shards ORDER BY shard_id')
        return [(shard_id, name) for shard_id, name in cur.fetchall()]


def get_deleted_users(conn, shard_id):
    """Return the uids recorded in ``shards.deleted_users`` for one shard.

    :param conn: open DB-API connection to the sharpei database.
    :param shard_id: shard whose deleted users are listed.
    :returns: list of uids, in the order the database returns them.
    """
    # The cursor is closed deterministically instead of being left to the GC.
    with conn.cursor() as cur:
        cur.execute(
            '''
            SELECT uid FROM shards.deleted_users
             WHERE shard_id = %(shard_id)s
            ''',
            vars=dict(
                shard_id=shard_id,
            ),
        )
        return [uid for (uid,) in cur.fetchall()]


def check_shard_deleted_users(shard_dsn, uids):
    """Report users on a shard that should be, but are not, marked deleted.

    Prints one line per uid from ``uids`` that the shard still has as
    ``is_here`` and ``NOT is_deleted`` although it owns no rows in
    ``mail.box``. The shard is only read, never modified.

    :param shard_dsn: libpq connection string of the shard.
    :param uids: uids to examine; when empty the shard is not contacted.
    """
    if not uids:
        # Nothing to look for - don't open a connection to the shard at all.
        return

    # psycopg2's ``with conn`` only ends the transaction; closing() is what
    # actually releases the connection (otherwise one leaks per shard).
    with closing(psycopg2.connect(shard_dsn)) as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                '''
                SELECT uid FROM mail.users as mu
                 WHERE uid = ANY(%(uids)s::bigint[])
                   AND is_here
                   AND NOT is_deleted
                   AND NOT EXISTS (select 1 from mail.box where uid = mu.uid)
                ''',
                vars=dict(
                    uids=uids,
                ),
            )
            for row in cur.fetchall():
                print(f'Bad is_deleted for uid {row[0]}')


def fix_shard_deleted_users(shard_dsn, uids):
    """Mark as deleted the users from ``uids`` that have no mailboxes on a shard.

    For every uid that the shard has as ``is_here`` and ``NOT is_deleted`` but
    with no rows in ``mail.box``, sets ``is_deleted = true`` and
    ``purge_date`` to 200 days from now. The update is committed as one
    transaction and rolled back if anything fails.

    :param shard_dsn: libpq connection string of the shard's writable instance.
    :param uids: uids to consider; when empty the shard is not contacted.
    """
    if not uids:
        # Nothing to fix - don't open a connection to the shard at all.
        return

    # psycopg2's ``with conn`` commits/rolls back but does not close the
    # connection; closing() releases it (otherwise one leaks per shard).
    with closing(psycopg2.connect(shard_dsn)) as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                '''
                UPDATE mail.users as mu
                   SET is_deleted = true, purge_date = now() + '200 days'::interval
                 WHERE uid = ANY(%(uids)s::bigint[])
                   AND is_here
                   AND NOT is_deleted
                   AND NOT EXISTS (select 1 from mail.box where uid = mu.uid)
                ''',
                vars=dict(
                    uids=uids,
                ),
            )
            fixed = cur.rowcount

    # Reported only after the commit in ``with conn`` has succeeded, so a
    # failed commit is never logged as a fix. rowcount may be -1 if unknown.
    if fixed > 0:
        log.info('Fixed %d users', fixed)


@click.command('fix-transferred-deleted-users')
@click.option('--env', 'env_name', default=Envs.prod.value.name)
@click.option('--db_user', default='transfer')
@click.option('--only-check', is_flag=True)
def main(env_name: str, db_user: str, only_check: bool = False):
    """Find (and by default fix) users deleted in sharpei but still live on shards.

    For every shard registered in sharpei, the uids from shards.deleted_users
    are looked up on the shard's writable instance. Users still flagged
    is_here and not is_deleted, with no mailboxes, are printed with
    --only-check and marked deleted otherwise. A failure on one shard is
    reported and the remaining shards are still processed.
    """
    # fix_shard_deleted_users reports through log.info, which an unconfigured
    # root logger drops; basicConfig is a no-op if logging is already set up.
    logging.basicConfig(level=logging.INFO)

    env: Env = Envs[env_name].value
    users = get_users_from_yav(env.users_file)
    # Resolved once: a missing user fails fast instead of being reported as
    # a "Bad Shard" for every shard.
    db_password = users[db_user]['password']

    # psycopg2's connection context manager does not close the connection;
    # closing() does.
    with closing(psycopg2.connect(env.sharpei_dsn)) as sharpei_conn:
        for shard_id, shard_name in get_shards(sharpei_conn):
            print(f'*** Shard {shard_name} with shard_id {shard_id}')
            try:
                uids = get_deleted_users(sharpei_conn, shard_id)
                shard_dsn = get_shard_master_dsn(sharpei_conn, shard_id, db_user, db_password)
                if only_check:
                    check_shard_deleted_users(shard_dsn, uids)
                else:
                    fix_shard_deleted_users(shard_dsn, uids)
            except Exception as e:
                # Best effort per shard: report and continue with the next one.
                print(f'*** Bad Shard {shard_name} {e}')
                # A failed sharpei query leaves its transaction aborted, which
                # would make every later query on sharpei_conn fail as well.
                sharpei_conn.rollback()


def _quote_dsn_value(value: str) -> str:
    """Render ``value`` for a libpq ``keyword=value`` connection string.

    Values that are empty or contain whitespace, single quotes or backslashes
    are wrapped in single quotes with backslashes and quotes backslash-escaped,
    as libpq requires; all other values are returned unchanged.
    """
    if value and not any(ch.isspace() or ch in "'\\" for ch in value):
        return value
    escaped = value.replace('\\', '\\\\').replace("'", "\\'")
    return f"'{escaped}'"


def get_shard_master_dsn(conn, shard_id: int, user: str, passwd: str):
    """Build a libpq DSN that connects to the writable instance of a shard.

    Every host registered for ``shard_id`` in ``shards.instances`` is listed
    in ``host=``; ``target_session_attrs=read-write`` makes libpq pick the one
    that accepts writes. Port 6432 and database ``maildb`` are fixed.

    :param conn: open DB-API connection to the sharpei database.
    :param shard_id: shard whose instances are looked up.
    :param user: database user to connect as.
    :param passwd: password of ``user``; quoted/escaped as libpq requires.
    :returns: space-separated ``keyword=value`` connection string.
    :raises ValueError: if no instances are registered for the shard (an
        empty ``host=`` would make libpq fall back to its default host).
    """
    with conn.cursor() as cur:
        cur.execute(
            'select host from shards.instances where shard_id = %(shard_id)s',
            vars=dict(shard_id=shard_id),
        )
        hosts = [row[0] for row in cur.fetchall()]

    if not hosts:
        raise ValueError(f'no instances registered for shard_id {shard_id}')

    return ' '.join((
        f'host={_quote_dsn_value(",".join(hosts))}',
        'port=6432',
        'dbname=maildb',
        f'user={_quote_dsn_value(user)}',
        f'password={_quote_dsn_value(passwd)}',
        'target_session_attrs=read-write',
    ))


if __name__ == '__main__':
    # Run the click command when this file is executed as a script.
    main()
