import argparse
import errno
import logging
import os
import psycopg2

from multiprocessing import Pool

from mail.tools.sql_execute_per_shard.lib.get_shards import get_sharpei_response, get_master_dns, get_password_from_yav

# Module logger; ShardProcessor.__call__ reports per-shard failures through it.
log = logging.getLogger(__name__)


def make_parser():
    """Build the command-line parser for this script.

    Every option except ``--force`` takes a string value.
    ``--sql_before_all`` and ``--result_dir`` (default ``./result``) are
    optional; the others are mandatory. ``--force`` is a boolean flag.
    """
    parser = argparse.ArgumentParser(
        prog="Run sharpei on all shards",
        description="""
    Program can be run both as interactive or non-interactive.
    In case of non-interactive mode, it is recommended to redirect
    standard output to the file, so that it does not mess with the progress
    bar.""",
    )

    # (flag, required, default, help) -- registration order is the order
    # shown by --help, so keep it stable.
    string_options = (
        ("--sharpei_host", True, None, "Sharpei host"),
        ("--sharpei_access_machine", True, None,
         "Machine that has access to sharpei via net"),
        ("--db_user", True, None, "Database user"),
        ("--sql_before_all", False, None,
         "The sql file to execute on each shard before actual query"),
        ("--sql_file_name", True, None, "The sql file to execute on each shard"),
        ("--result_dir", False, "./result", "Directory to place results"),
    )
    for flag, is_required, default_value, help_text in string_options:
        parser.add_argument(flag, type=str, required=is_required,
                            default=default_value, help=help_text)

    parser.add_argument("--force", action='store_true', help="Rewrite results")
    return parser


def mkdir_p(path, **kwargs):
    """Create *path* and any missing parent directories, like ``mkdir -p``.

    An already existing directory is not an error. Extra keyword arguments
    (e.g. ``mode``) are forwarded to :func:`os.makedirs`.

    Raises:
        OSError: if the directory cannot be created, including
            ``FileExistsError`` when *path* exists but is not a directory.
    """
    # exist_ok replaces the manual EEXIST/isdir check. It is forced on because
    # tolerating an existing directory is this helper's whole contract.
    kwargs["exist_ok"] = True
    os.makedirs(path, **kwargs)


def main():
    """Run the SQL file from the command line on every master shard.

    Shards are listed through sharpei and processed by ``ShardProcessor`` in a
    pool of 100 worker processes. Output goes to ``--result_dir``. When the
    run ends, or is interrupted, the ids of shards without a successful
    result are printed.
    """
    args = make_parser().parse_args()

    before_query = None
    if args.sql_before_all:
        with open(args.sql_before_all, "r") as file:
            before_query = file.read()

        print(f"""The query to be executed as setup:
            ====
            {before_query}
            ====""")

    with open(args.sql_file_name, "r") as file:
        query = file.read()

    print(f"""The query to be executed:
    ====
    {query}
    ====""")

    response = get_sharpei_response(args.sharpei_host, args.sharpei_access_machine)
    db_password = get_password_from_yav(args.db_user, args.sharpei_host)
    # Items are indexed as (shard id, connection url) by ShardProcessor.
    shards = list(get_master_dns(response, args.db_user, db_password))

    mkdir_p(args.result_dir)

    proc = ShardProcessor(result_dir=args.result_dir, before_query=before_query, query=query, force=args.force)
    # imap_unordered yields in completion order, so results cannot be matched
    # to positions in `shards`; remember which shard ids actually finished.
    done = set()
    try:
        # Leaving the `with` terminates the workers, also on interruption.
        with Pool(100) as pool:
            for i, shard in enumerate(pool.imap_unordered(proc, shards)):
                if shard is None:
                    # ShardProcessor returns None for a shard that failed;
                    # the failure has already been printed and logged.
                    print(f'#{i:03} Failed on a shard, see the log above')
                else:
                    done.add(shard)
                    print(f'#{i:03} Done with shard {shard}')
    finally:
        # Includes both failed shards and shards never reached.
        unprocessed = [shard_info[0] for shard_info in shards if shard_info[0] not in done]
        print(f'Unprocessed shards: {unprocessed}')


class ShardProcessor:
    """Callable that runs the configured queries on one shard.

    main() hands an instance to ``Pool.imap_unordered``. The class is defined
    at module level and holds only plain attributes, so the pool can pickle
    it. A failure on one shard is reported and turned into a ``None`` result,
    so it does not abort the run over the remaining shards.
    """

    def __init__(self, result_dir, before_query, query, force):
        self.result_dir = result_dir      # directory receiving one result file per shard
        self.before_query = before_query  # optional setup SQL; None to skip
        self.query = query                # SQL whose result rows are exported
        self.force = force                # re-run shards that already have a result file

    def __call__(self, shard_item):
        """Process ``shard_item`` (indexed as ``[0]`` shard id, ``[1]`` connection url).

        Returns:
            The shard id on success, or ``None`` if processing raised.
        """
        try:
            return process_shard(self.result_dir, self.before_query, self.query, self.force,
                                 shard_item[0], shard_item[1])
        except Exception:
            # Catch Exception, not BaseException, so Ctrl-C / SystemExit still
            # stop the worker. Report only the shard id: the url comes from
            # get_master_dns(), which main() gives the DB password, so it may
            # embed credentials.
            print(f'Failed to run on shard {shard_item[0]}')
            log.exception('Failed to run on shard %s', shard_item[0])
            return None


def _copy_sql(query):
    """Wrap *query* in a ``COPY (...) TO STDOUT`` producing tab-separated CSV with a header."""
    # A trailing ';' is usual at the end of .sql files but is a syntax error
    # inside COPY (...). The newlines keep a trailing '--' comment in the file
    # from swallowing the closing parenthesis.
    body = query.strip().rstrip(';').rstrip()
    return f"COPY (\n{body}\n) TO STDOUT WITH CSV HEADER DELIMITER E'\\t'"


def _export_shard(url, shard, before_query, query, fd):
    """Connect to *url*, run the optional setup SQL, then COPY *query*'s rows into *fd*."""
    connect = psycopg2.connect(url)
    try:
        if before_query:
            print(f'Run setup on {shard}')
            with connect.cursor() as cursor:
                cursor.execute(before_query)
            connect.commit()

        print(f'Run query on {shard}')
        with connect.cursor() as cursor:
            cursor.copy_expert(_copy_sql(query), fd)
        connect.commit()
    finally:
        # psycopg2's connection context manager only ends the transaction and
        # does not close the connection, hence the explicit close().
        connect.close()


def process_shard(result_dir, before_query, query, force, shard, url):
    """Export the result of *query* on one shard into ``result_dir/<shard>``.

    Args:
        result_dir: directory holding one result file per shard.
        before_query: optional SQL executed (and committed) before *query*.
        query: SQL usable inside ``COPY (...)``; its rows are written as
            tab-separated CSV with a header row.
        force: re-run even if the shard's result file already exists.
        shard: shard id; ``str(shard)`` is the result file name.
        url: connection string passed to ``psycopg2.connect``.

    Returns:
        *shard*, both when it was processed and when it was skipped.

    Raises:
        Any error from psycopg2 or file I/O. In that case no new result file
        is created, and an existing one is left untouched.
    """
    # The url is not printed: main() builds it with the DB password, so it
    # may embed credentials.
    print(f'Executing on shard {shard}...')
    shard_result = os.path.join(result_dir, str(shard))
    if not force and os.path.exists(shard_result):
        print(f'Shard {shard} is already processed')
        return shard

    # Write into a temporary file and move it into place only on success.
    # A partial file must never look "already processed" to a later run
    # without --force.
    tmp_result = shard_result + '.tmp'
    try:
        with open(tmp_result, 'w') as fd:
            _export_shard(url, shard, before_query, query, fd)
        os.replace(tmp_result, shard_result)
    finally:
        # Still present only when something above failed: drop partial output.
        if os.path.exists(tmp_result):
            os.remove(tmp_result)
    return shard


if __name__ == "__main__":  # run only when executed as a script, not on import
    main()
