import argparse
import hashlib
import logging
import os
import re
import sys
from contextlib import closing
from dataclasses import dataclass
from multiprocessing import Pool

import backoff
import psycopg2
import psycopg2.extras
import requests
import ujson as json


# Values accepted by --operation (see make_parser).
UPDATE_BY_UIDS = 'update_by_uids'
UPDATE_FOR_ALL = 'update_for_all'
RUN_TASK = 'run_task'
INIT = 'init'
UNLOCK = 'unlock'
LOAD_UIDS_FOR_UPDATE = 'load_uids'
COUNT_UIDS_FOR_UPDATE = 'count_uids_for_update'

# Fixed paths on the machine running this script.
SETTINGS_INIT_FILE = '/etc/settings/init_settings.json'  # JSON payload read for the init operation
FILE_UIDS_PATH = '/etc/settings/uids'  # directory holding one uid file per shard name


# Log to a file in the current working directory.
logging.basicConfig(
    level=logging.INFO,
    filename="settings_modification.log",
    datefmt='[%Y/%m/%d:%H:%M:%S]',
    format='%(asctime)s %(levelname)-8s %(message)s',
)


@dataclass()
class Task:
    """Record holding a single ``name`` string.

    NOTE(review): not referenced by any other code visible in this file;
    possibly unused — confirm before removing.
    """

    name: str


@dataclass()
class CreateTaskResult:
    """Row of ``Queries.CREATE_TASK_Q`` (``... AS status``), built by run_select.

    Worker.prologue treats a falsy ``status`` as "task was not created".
    """

    status: bool


@dataclass()
class ModificationResult:
    """Single-column ``count`` row, built by run_select from the results of
    ``Queries.MODIFY_Q`` and ``Queries.COUNT_UIDS_FOR_UPDATE_Q``."""

    count: int


class Queries:
    """SQL statements issued by Worker.

    ``%(name)s`` placeholders are psycopg2 named parameters, supplied as keyword
    arguments to run_select / run_query.
    """

    # Ask code.create_task_bulk_modify to create a task; its result is aliased
    # to ``status`` and read into CreateTaskResult by Worker.prologue.
    CREATE_TASK_Q = """SELECT code.create_task_bulk_modify(
                        %(i_task)s,
                        %(i_modify_type)s,
                        %(i_settings)s
                    ) AS status"""

    # Insert row id=1 into settings.bulk_modification with an empty name,
    # doing nothing if it conflicts with an existing row.
    INSERT_TASK_Q = """INSERT INTO settings.bulk_modification
                           (id, name, create_date)
                       VALUES
                           (1, '', NOW())
                       ON CONFLICT DO NOTHING"""

    # Reset the name stored on row id=1 to ''; issued by the unlock operation.
    UNLOCK_Q = """UPDATE settings.bulk_modification
                         SET name = ''
                       WHERE id = 1"""

    # One call of code.bulk_modify_settings(); the run_task operation repeats
    # it while the returned ``count`` is greater than 0.
    MODIFY_Q = """SELECT code.bulk_modify_settings() AS count"""

    # Number of rows in settings.users_for_update whose is_modified is false.
    COUNT_UIDS_FOR_UPDATE_Q = """SELECT count(uid)
                                   FROM settings.users_for_update
                                  WHERE NOT is_modified"""


def make_parser():
    """Build the command-line parser.

    Every option is parsed as a string. ``--operation``, ``--sharpei_host`` and
    ``--user`` are mandatory; ``--shards`` defaults to ``"all"``.
    """
    operations = [UPDATE_BY_UIDS, UPDATE_FOR_ALL, INIT, UNLOCK, RUN_TASK,
                  LOAD_UIDS_FOR_UPDATE, COUNT_UIDS_FOR_UPDATE]

    # (flag, add_argument keyword options) in the order they appear in --help.
    option_specs = [
        ("--operation", {"required": True,
                         "help": "modification operation",
                         "choices": operations}),
        ("--shards", {"required": False,
                      "help": "list of shards for modification, format: shard1,shard2,... or all",
                      "default": "all"}),
        ("--setting", {"required": False,
                       "help": 'setting for modification, format: {"setting1":"value1", "setting2"...}'}),
        ("--sharpei_host", {"required": True,
                            "help": "sharpei host"}),
        ("--user", {"required": True,
                    "help": "user for sql connect"}),
        ("--file_name", {"required": False,
                         "help": "file with shards name"}),
    ]

    parser = argparse.ArgumentParser(prog="Modification settings on all shards",
                                     description="")
    for flag, options in option_specs:
        parser.add_argument(flag, type=str, **options)
    return parser


@backoff.on_exception(
    backoff.expo,
    (requests.HTTPError, requests.exceptions.Timeout, requests.exceptions.ConnectionError),
    max_time=1,
    max_tries=2,
)
def sharpei_stat(location):
    """Fetch the shard map from sharpei's ``/v2/stat`` endpoint.

    Args:
        location: base URL of the sharpei service (main passes ``http://<host>``).

    Returns:
        The decoded JSON body; get_master_dsn iterates it as a mapping of shard
        entries carrying ``name`` and ``databases`` keys.

    Raises:
        requests.HTTPError, requests.exceptions.Timeout or
        requests.exceptions.ConnectionError once the retry budget
        (at most 2 tries within 1 s) is exhausted.
    """
    response = requests.get(location + '/v2/stat', timeout=1)
    # requests does not raise on 4xx/5xx by itself; without this call the
    # HTTPError retry above could never fire and an error body would be
    # returned as if it were the shard map.
    response.raise_for_status()
    return response.json()


def get_master_dsn(sharpei_response, user):
    """Yield ``(shard_name, dsn)`` for each master database in a sharpei response.

    Args:
        sharpei_response: mapping whose values are shard entries with a ``name``
            and a ``databases`` list; each database has a ``role`` and an
            ``address`` with ``host``, ``dbname`` and ``port``.
        user: database user placed into every DSN.

    A shard with several masters yields several pairs; a shard with none yields
    nothing.
    """
    for shard in sharpei_response.values():
        shard_name = shard['name']
        masters = (entry["address"] for entry in shard['databases'] if entry["role"] == "master")
        for address in masters:
            dsn = f"host={address['host']} dbname={address['dbname']} user={user} port={address['port']}"
            yield shard_name, dsn


def filter_dsn(dsns, shards):
    """Restrict ``dsns`` (shard name -> DSN) to the names listed in ``shards``.

    An empty/falsy ``shards`` means "no filter": the very same ``dsns`` object
    is returned. Otherwise a new dict with only the matching names is built.
    """
    if shards:
        return dict(pair for pair in dsns.items() if pair[0] in shards)
    return dsns


def get_settings(settings_args):
    """Validate the ``--setting`` JSON and wrap it as the task's settings payload.

    ``settings_args`` must be a non-empty JSON object mapping setting names to
    string values, e.g. ``{"setting1": "value1"}``. Every name must consist
    entirely of ``[a-z0-9_]`` (at least one character); every value must consist
    entirely of ``[a-z0-9_]`` (may be empty).

    Returns:
        ``json.dumps({"single_settings": settings})`` on success, or ``None``
        when the input is not valid JSON, is not a non-empty object, or holds a
        name/value that fails validation. Every rejection is logged.
    """
    try:
        settings = json.loads(settings_args)
    except (ValueError, TypeError) as e:
        logging.exception(e)
        return None

    # Same rule as get_init_settings: an empty settings set is not a valid task.
    if not isinstance(settings, dict) or not settings:
        logging.error(f"settings must be a non-empty JSON object: {settings_args}")
        return None

    for setting, value in settings.items():
        # fullmatch, not match: re.match anchors only at the start, so the
        # "*" pattern accepted every value and the "+" pattern accepted any
        # name whose first character was valid (e.g. "a; drop").
        if (not isinstance(value, str)
                or not re.fullmatch(r"[a-z0-9_]+", setting)
                or not re.fullmatch(r"[a-z0-9_]*", value)):
            logging.error(f"invalid setting {setting!r}: {value!r}")
            return None

    return json.dumps(dict(single_settings=settings))


def get_shards_name_from_file(file_name_args):
    """Read shard names from a text file, one per line.

    Each line is stripped of surrounding whitespace; blank lines are kept as
    empty strings rather than dropped.
    """
    with open(file_name_args, 'r') as handle:
        raw_lines = handle.readlines()
    return list(map(str.strip, raw_lines))


def get_shards_name_from_args(shards_args):
    """Parse ``--shards``: a comma-separated list of names, or ``all``.

    ``"all"`` (exact, case-sensitive) yields ``[]``, which filter_dsn treats as
    "every shard". Other input is split on commas and each piece stripped.
    """
    if shards_args != 'all':
        return list(map(str.strip, shards_args.split(",")))
    return []


def get_init_settings():
    """Load the init payload from ``SETTINGS_INIT_FILE``.

    Returns the file's JSON re-serialized as a string, or ``None`` when it has
    no ``single_settings`` entry or that entry is empty.
    """
    with open(SETTINGS_INIT_FILE, 'r') as init_file:
        payload = json.load(init_file)
        has_settings = 'single_settings' in payload and len(payload['single_settings']) != 0
        return json.dumps(payload) if has_settings else None


@backoff.on_exception(backoff.expo, psycopg2.DatabaseError, max_time=1, max_tries=2)
def run_select(connect, qurey, result_type, **kwargs):
    """Run a SELECT and map every row onto ``result_type``.

    ``kwargs`` is passed to ``cursor.execute`` as the named-parameter mapping
    for ``%(name)s`` placeholders. Rows are fetched with a DictCursor and
    expanded as keyword arguments, so column names must match the fields of
    ``result_type``. Retried once on psycopg2.DatabaseError (at most 2 tries
    within 1 s).
    """
    cursor = connect.cursor(cursor_factory=psycopg2.extras.DictCursor)
    with cursor:
        cursor.execute(qurey, kwargs)
        records = cursor.fetchall()
        return [result_type(**record) for record in records]


@backoff.on_exception(backoff.expo, psycopg2.DatabaseError, max_time=1, max_tries=2)
def run_query(connect, qurey, **kwargs):
    """Execute a statement and return ``cursor.rowcount``.

    ``kwargs`` is the named-parameter mapping for ``%(name)s`` placeholders.
    Retried once on psycopg2.DatabaseError (at most 2 tries within 1 s).
    """
    cursor = connect.cursor(cursor_factory=psycopg2.extras.DictCursor)
    with cursor:
        cursor.execute(qurey, kwargs)
        affected = cursor.rowcount
    return affected


def copy_uids_to_users_for_update(connect, file):
    """Bulk-load uids, one per line, from ``file`` into settings.users_for_update.

    Uses ``copy_expert`` with an explicit COPY statement. Since psycopg2 2.9,
    ``copy_from`` quotes its ``table`` argument, so a schema-qualified name such
    as ``settings.users_for_update`` would be treated as one (nonexistent)
    identifier. The COPY text format defaults (tab separator, ``\\N`` as NULL)
    match the old ``copy_from`` defaults.

    Retried once on psycopg2.DatabaseError (at most 2 tries within 1 s), as
    before. When the file is seekable it is rewound to its starting position
    before each attempt, so a retry reloads the whole file instead of only the
    part the failed attempt had not yet read. (A failed COPY loads no rows when
    the connection is in autocommit mode, as Worker sets it.)
    """
    start = file.tell() if file.seekable() else None

    @backoff.on_exception(backoff.expo, psycopg2.DatabaseError, max_time=1, max_tries=2)
    def _copy():
        if start is not None:
            file.seek(start)
        with connect.cursor() as cursor:
            cursor.copy_expert("COPY settings.users_for_update (uid) FROM STDIN", file)

    _copy()


class Worker:
    """Callable that runs one configured operation against a single shard.

    main maps an instance over ``(shard_name, dsn)`` pairs with a
    multiprocessing Pool, so instances are sent to worker processes.

    Attributes:
        name: md5 hex digest of ``operation + setting``; passed to the database
            as ``i_task`` and included in log lines.
        setting: settings payload string ('' for operations that take none).
        operation: one of the operation constants accepted by ``--operation``.
    """

    def __init__(self, operation, setting):
        # Deterministic task name: the same operation and setting always
        # produce the same digest.
        h = hashlib.md5(operation.encode('utf-8') + setting.encode('utf-8'))
        self.name = h.hexdigest()
        self.setting = setting
        self.operation = operation

    def __call__(self, shard):
        """Run the operation on ``shard`` (``shard[0]`` name, ``shard[1]`` DSN).

        Returns ``[True, name]`` on success and ``[False, name]`` when task
        creation is refused. Exceptions raised while connecting or running the
        operation are logged and also reported as ``[False, name]``.
        """
        logging.info(f"start operation {self.operation} {self.name} on {shard[0]}")
        try:
            # psycopg2's own `with conn:` only commits/rolls back the current
            # transaction and leaves the connection open; closing() releases it.
            with closing(psycopg2.connect(shard[1])) as conn:
                conn.set_session(autocommit=True)
                if self.operation == UNLOCK:
                    run_query(conn, Queries.UNLOCK_Q)
                elif self.operation == LOAD_UIDS_FOR_UPDATE:
                    # The uid file for a shard is named after the shard.
                    with open(os.path.join(FILE_UIDS_PATH, shard[0]), 'r') as file:
                        copy_uids_to_users_for_update(conn, file)
                elif self.operation == COUNT_UIDS_FOR_UPDATE:
                    result = run_select(conn, Queries.COUNT_UIDS_FOR_UPDATE_Q, ModificationResult)
                    logging.info(f"result {result[0].count} of {self.operation} on {shard[0]}")
                elif self.operation == RUN_TASK:
                    # Repeat the modify call until it returns no rows or a count of 0.
                    result = run_select(conn, Queries.MODIFY_Q, ModificationResult)
                    while len(result) > 0 and result[0].count > 0:
                        result = run_select(conn, Queries.MODIFY_Q, ModificationResult)
                elif not self.prologue(conn):
                    # Remaining operations (init / update_*) create a task;
                    # the database refused to create it.
                    return [False, shard[0]]
                logging.info(f"end operation {self.operation} {self.name} on {shard[0]}")
                return [True, shard[0]]
        except Exception as e:
            logging.error(f'failed start task on shard with name {shard[0]}')
            logging.exception(e)
            return [False, shard[0]]

    def prologue(self, conn):
        """Create the bulk-modification task for this worker on ``conn``.

        First inserts row id=1 into settings.bulk_modification (no-op on
        conflict), then calls code.create_task_bulk_modify with this worker's
        name, operation and setting.

        Returns:
            True if the call returned a row with a truthy ``status``, else False
            (an error is logged).
        """
        run_query(conn, Queries.INSERT_TASK_Q)
        result = run_select(
            conn, Queries.CREATE_TASK_Q,
            CreateTaskResult,
            i_task=self.name,
            i_modify_type=self.operation,
            i_settings=self.setting,
        )
        if len(result) == 0 or not result[0].status:
            logging.error(f"not create task for {self.name} operation")
            return False
        return True


def main():
    """Command-line entry point: run one ``--operation`` on the selected shards.

    Validates the settings payload for operations that need one, resolves each
    shard's master DSN through sharpei, then runs a Worker on every selected
    shard in a pool of 10 processes, printing one status line per shard.

    Exits with status 1 when:
      * the settings payload is invalid;
      * ``--file_name`` lists no shard names;
      * resolving shards or dispatching the work raises;
      * any shard reports a failed status.
    """
    args = make_parser().parse_args()
    setting = ''

    if args.operation in (UPDATE_FOR_ALL, UPDATE_BY_UIDS):
        setting = get_settings(args.setting) if args.setting else None
        if setting is None:
            print('format for settings must: {"setting1":"value1", "setting2"...}')
            sys.exit(1)

    if args.operation == INIT:
        setting = get_init_settings()
        if setting is None:
            print('wrong init settings format')
            sys.exit(1)

    failed = 0
    try:
        shards_for_filter = (
            get_shards_name_from_file(args.file_name)
            if args.file_name
            else get_shards_name_from_args(args.shards)
        )
        # filter_dsn treats an empty list as "every shard"; an empty or
        # all-blank shard file must not silently widen the run to all shards.
        if args.file_name and not any(shards_for_filter):
            print(f'no shard names found in {args.file_name}')
            sys.exit(1)

        dsns = dict(get_master_dsn(sharpei_stat(f'http://{args.sharpei_host}'), args.user))
        targets = filter_dsn(dsns, shards_for_filter)
        worker = Worker(args.operation, setting)
        # The context manager shuts the worker processes down on exit, even
        # when an exception escapes the loop.
        with Pool(10) as pool:
            for i, (status, shard) in enumerate(pool.imap_unordered(worker, targets.items())):
                print(f'#{i:03} Done on shard {shard} with status {status}')
                if not status:
                    failed += 1
    except Exception as e:
        logging.exception(e)
        print(f'operation {args.operation} failed: {e}', file=sys.stderr)
        sys.exit(1)

    if failed:
        sys.exit(1)
