#!/usr/bin/env python
# coding: utf-8
import json
import logging
from collections import defaultdict
from typing import List, Iterable, Dict  # noqa

from mail.pypg.pypg.common import qexec
from mail.pypg.pypg.query_conf import load_from_package
from ora2pg.tools import http
from ora2pg.tools.find_master_helpers import get_sharddb_pooled_conn
from psycopg2 import IntegrityError
from psycopg2.errorcodes import UNIQUE_VIOLATION

log = logging.getLogger(__name__)

QUERIES = load_from_package(__package__, __file__)


def get_shard_id(uid, dsn):
    """Return the first column of the ``QUERIES.get_shard_id`` result for ``uid``.

    The query runs on a connection obtained from
    ``get_sharddb_pooled_conn(dsn, autocommit=False)``.
    Returns ``None`` when the query yields no row.
    """
    with get_sharddb_pooled_conn(dsn, autocommit=False) as conn:
        found = qexec(conn, QUERIES.get_shard_id, uid=uid).fetchone()
        return found[0] if found is not None else None


class ShardInfo(object):
    """Plain value holder pairing a shard id with a "deleted" flag."""

    def __init__(self, shard_id, is_deleted):
        # shard_id: whatever id the caller resolved (may be None);
        # is_deleted: flag supplied by the caller, stored as given.
        self.shard_id, self.is_deleted = shard_id, is_deleted


def get_shard_info(uid, dsn):
    """Resolve ``uid`` to a :class:`ShardInfo`.

    ``QUERIES.get_shard_id`` is tried first; a hit is reported with
    ``is_deleted=False``. Otherwise ``QUERIES.get_deleted_shard_id`` is run on
    the same connection and the result is reported with ``is_deleted=True``.
    When neither query returns a row the result is ``ShardInfo(None, True)``.
    """
    with get_sharddb_pooled_conn(dsn, autocommit=False) as conn:
        live_row = qexec(conn, QUERIES.get_shard_id, uid=uid).fetchone()
        if live_row is not None:
            return ShardInfo(live_row[0], False)

        deleted_row = qexec(conn, QUERIES.get_deleted_shard_id, uid=uid).fetchone()
        return ShardInfo(deleted_row[0] if deleted_row else None, True)


class SharpeiError(RuntimeError):
    """Base class for the sharpei-specific errors defined in this module."""


class UserAlreadyInited(SharpeiError):
    """Raised by init_in_sharpei on a unique violation when allow_inited is false."""


class NoAliveDatabase(SharpeiError):
    """Raised by get_pg_dsn_from_sharpei when the conninfo response has no usable address."""


def init_in_sharpei(uid, dsn, allow_inited, shard_id=None):
    """Record the ``uid`` -> ``shard_id`` mapping via ``QUERIES.save_user_in_shard``.

    When ``shard_id`` is ``None`` an id is first obtained from
    ``QUERIES.generate_shard_id`` on the same connection.

    If the save hits a unique violation the connection is rolled back and then:

    * ``allow_inited`` falsy  -> :class:`UserAlreadyInited` is raised;
    * ``allow_inited`` truthy -> after the connection block is left, the stored
      shard id is overwritten through :func:`update_shard_id` and a warning
      is logged.

    :param uid: user id to register.
    :param dsn: DSN handed to ``get_sharddb_pooled_conn``.
    :param allow_inited: whether an already existing record may be overwritten.
    :param shard_id: target shard id, or ``None`` to have one generated.
    :returns: the shard id that was saved (given or generated).
    :raises AssertionError: if ``generate_shard_id`` returned no row.
    :raises UserAlreadyInited: see above.
    :raises IntegrityError: any integrity error other than a unique violation
        is re-raised unchanged.
    """
    need_update_shard_id = False
    err_str = 'User {0} is initialized in shards.users with wrong shard_id'.format(uid)
    with get_sharddb_pooled_conn(dsn, autocommit=False) as conn:
        if shard_id is None:
            row = qexec(conn, QUERIES.generate_shard_id, uid=uid).fetchone()
            if not row:
                # Raised explicitly instead of using an ``assert`` statement:
                # asserts are removed under ``python -O``, which would turn
                # this into an opaque TypeError on ``row[0]`` below.  The
                # exception type is kept for backward compatibility.
                raise AssertionError(
                    'code.generate_shard_id({0}) failed to give shard_id!'.format(uid)
                )
            shard_id = row[0]
        try:
            qexec(
                conn,
                QUERIES.save_user_in_shard,
                uid=uid,
                shard_id=shard_id
            )
        except IntegrityError as exc:
            if exc.pgcode != UNIQUE_VIOLATION:
                raise
            # Reset the failed transaction before deciding what to do next.
            conn.rollback()
            if not allow_inited:
                raise UserAlreadyInited(err_str)
            need_update_shard_id = True
    # The overwrite happens only after the connection block above is left,
    # because update_shard_id obtains its own connection.
    if need_update_shard_id:
        log.warning('%s, changing shard_id to %s', err_str, shard_id)
        update_shard_id(dsn, uid, shard_id)
    return shard_id


def add_deleted_user_to_sharpei(uid, dsn, shard_id):
    """Record ``uid`` with ``shard_id`` via ``QUERIES.save_deleted_user_in_shard``.

    On a unique violation the connection is rolled back and, once the
    connection block is left, the existing record is overwritten through
    ``update_shard_id(..., is_deleted=True)`` with a warning logged.
    Any other ``IntegrityError`` propagates. Always returns ``shard_id``.
    """
    already_present = False
    with get_sharddb_pooled_conn(dsn, autocommit=False) as conn:
        try:
            qexec(conn, QUERIES.save_deleted_user_in_shard, uid=uid, shard_id=shard_id)
        except IntegrityError as exc:
            if exc.pgcode != UNIQUE_VIOLATION:
                raise
            conn.rollback()
            already_present = True

    if not already_present:
        return shard_id

    message = 'User {0} is initialized in shards.deleted_users with wrong shard_id'.format(uid)
    log.warning(message + ', changing shard_id to {0}'.format(shard_id))
    update_shard_id(dsn, uid, shard_id, is_deleted=True)
    return shard_id


def update_shard_id(dsn, uid, shard_id, is_deleted=False):
    """Run the shard-id update query for ``uid``.

    ``QUERIES.update_deleted_shard_id`` is used when ``is_deleted`` is truthy,
    ``QUERIES.update_shard_id`` otherwise. Returns ``None``.
    """
    with get_sharddb_pooled_conn(dsn, autocommit=False) as conn:
        query = QUERIES.update_deleted_shard_id if is_deleted else QUERIES.update_shard_id
        qexec(conn, query, uid=uid, shard_id=shard_id)


def remove_user_from_sharpei(dsn, uid):
    """Execute ``QUERIES.remove_user_from_shards`` for ``uid``.

    Logs a warning when the cursor reports a zero row count; nothing is
    raised in that case. Returns ``None``.
    """
    with get_sharddb_pooled_conn(dsn, autocommit=False) as conn:
        removal_query = QUERIES.remove_user_from_shards
        cursor = qexec(conn, removal_query, uid=uid)
        if cursor.rowcount:
            return
        log.warning('No rows removed after %s, uid: %r', removal_query, uid)


def http_request(url, do_retries=False):
    """Call ``http.request`` for ``url``, passing 400 and 404 as ``skip_retry_codes``."""
    return http.request(
        url=url,
        do_retries=do_retries,
        skip_retry_codes=[400, 404],
    )


def get_connstring_by_id(sharpei, shard_id, dsn_suffix):
    """Build a connection string for the master host of ``shard_id``.

    The master address is resolved by :func:`get_shard_by_id`, which reads
    sharpei's ``stat`` method and raises the same errors this function used
    to raise itself.

    :param sharpei: sharpei host passed to ``http.url_join``.
    :param shard_id: shard id; looked up in the response as ``str(shard_id)``.
    :param dsn_suffix: extra text appended to the connection string; ``None``
        (or any falsy value) is rendered as an empty suffix, matching
        :func:`get_pg_dsn_from_sharpei`, instead of the literal ``None``.
    :returns: ``'host=... port=... dbname=... <dsn_suffix>'``.
    :raises RuntimeError: if the shard is missing from the response or has no
        master host.
    """
    master_address = get_shard_by_id(sharpei, shard_id)
    # Build a fresh mapping so the parsed address is not mutated and an
    # explicit suffix always wins over any same-named key in the address.
    fields = dict(master_address, dsn_suffix=dsn_suffix or '')
    return 'host={host} port={port} dbname={dbname} {dsn_suffix}'.format(**fields)


def get_all_connstrings(sharpei, dsn_suffix):
    """Build a connection string for the master host of every shard in ``stat``.

    :param sharpei: sharpei host passed to ``http.url_join``.
    :param dsn_suffix: extra text appended to each connection string; ``None``
        (or any falsy value) is rendered as an empty suffix instead of the
        literal ``None``.
    :returns: list of ``'host=... port=... dbname=... <dsn_suffix>'`` strings,
        one per shard.
    :raises RuntimeError: if some shard has no host with role ``master``.
    """
    suffix = dsn_suffix or ''
    with http_request(
        http.url_join(
            host=sharpei,
            method='stat'
        )
    ) as fd:
        statresp = json.load(fd)

    # The other readers of this response in this module index it by
    # str(shard_id), i.e. it is a mapping of shard id -> list of hosts.
    # Iterating the mapping itself would yield the id strings, not the host
    # lists, so walk the values.  A plain sequence of host lists is still
    # accepted as-is.
    shards = statresp.values() if isinstance(statresp, dict) else statresp

    connstrings = []
    for shard_hosts in shards:
        addr = next((host['address'] for host in shard_hosts
                     if host['role'] == 'master'), None)
        if not addr:
            raise RuntimeError(
                'master host not found: %r' % (shard_hosts,)
            )
        connstrings.append(
            'host={host} port={port} dbname={dbname} {dsn_suffix}'.format(
                **dict(addr, dsn_suffix=suffix)
            )
        )
    return connstrings


def get_pg_dsn_from_sharpei(sharpei, uid, dsn_suffix):
    """Ask sharpei's ``conninfo`` method for ``uid`` and build a DSN from the reply.

    The request is made with ``force='true'``, ``mode='master'``,
    ``format='json'`` and retries enabled. The first entry of the response's
    ``addrs`` list supplies host, port and dbname; a falsy ``dsn_suffix`` is
    rendered as an empty string. The resulting DSN is logged at INFO level.

    Raises ``NoAliveDatabase`` when ``addrs`` is missing or empty.
    """
    conninfo_url = http.url_join(
        host=sharpei,
        method='conninfo',
        args=dict(
            uid=str(uid),
            force='true',
            mode='master',
            format='json',
        ),
    )
    with http_request(conninfo_url, do_retries=True) as fd:
        resp = json.load(fd)
        try:
            master_address = resp['addrs'][0]
        except (KeyError, IndexError) as exc:
            raise NoAliveDatabase('no alive database in sharpei response: %r: %s' % (resp, exc))

        suffix = dsn_suffix or ''
        pg_dsn = 'host={host} port={port} dbname={dbname} {dsn_suffix}'.format(
            dsn_suffix=suffix,
            **master_address
        )
        message = 'User uid={0} belongs to pg shard with dsn="{1}"'.format(uid, pg_dsn)
        log.info(message)
        return pg_dsn


def get_shard_name(sharpei, shard_id):
    """Return the ``name`` of ``shard_id`` from sharpei's ``v3/stat`` response.

    :param sharpei: sharpei host passed to ``http.url_join``.
    :param shard_id: shard id as an int or a string.
    :raises KeyError: if the shard is absent from the response.
    """
    with http_request(url=http.url_join(host=sharpei, method='v3/stat'), do_retries=True) as fd:
        shards = json.load(fd)
    # Keys of a decoded JSON object are always strings, so an integer id has
    # to be normalised before the lookup (as the other /stat readers in this
    # module already do); a string id is unaffected.
    return shards[str(shard_id)]['name']


def group_by_shards(sharpei,                # type: str
                    maildb_dsn_suffix,      # type: str
                    objects_with_uid,       # type: Iterable[object]
                    uid_getter=lambda o: o):
    """Group ``objects_with_uid`` by the DSN sharpei reports for each object's uid.

    ``uid_getter`` extracts the uid from an object (identity by default).
    ``get_pg_dsn_from_sharpei`` is called once per object. Returns a
    ``defaultdict(list)`` mapping DSN string -> list of objects, in input order.
    """
    grouped = defaultdict(list)  # type: Dict[str, List[object]]
    for item in objects_with_uid:
        shard_dsn = get_pg_dsn_from_sharpei(
            sharpei=sharpei, uid=uid_getter(item), dsn_suffix=maildb_dsn_suffix
        )
        grouped[shard_dsn].append(item)
    return grouped


def get_shard_by_uid(sharpei, uid, do_retries=True):
    """Return the decoded JSON reply of sharpei's ``conninfo`` method for ``uid`` (``mode='all'``)."""
    conninfo_url = http.url_join(
        host=sharpei,
        method='conninfo',
        args=dict(uid=str(uid), mode='all'),
    )
    with http_request(url=conninfo_url, do_retries=do_retries) as fd:
        return json.load(fd)


def get_shard_by_deleted_uid(sharpei, uid, do_retries=True):
    """Return the decoded JSON reply of sharpei's ``deleted_conninfo`` method for ``uid`` (``mode='all'``)."""
    deleted_conninfo_url = http.url_join(
        host=sharpei,
        method='deleted_conninfo',
        args=dict(uid=str(uid), mode='all'),
    )
    with http_request(url=deleted_conninfo_url, do_retries=do_retries) as fd:
        return json.load(fd)


def get_shard_id_by_name(sharpei, shard_name):
    """Return the key of the first ``v3/stat`` entry whose ``name`` equals ``shard_name``.

    The key is returned exactly as it appears in the decoded response.
    Returns ``None`` when no entry matches.
    """
    stat_url = http.url_join(host=sharpei, method='v3/stat')
    with http_request(stat_url, do_retries=True) as fd:
        shards = json.load(fd)
        matching_ids = (
            candidate for candidate in shards
            if shards[candidate]['name'] == shard_name
        )
        return next(matching_ids, None)


def get_shard_by_id(sharpei, shard_id):
    """Return the ``address`` entry of the master host of ``shard_id`` from ``stat``.

    The shard is looked up in the response under ``str(shard_id)``.
    Raises ``RuntimeError`` when the shard is absent from the response, or
    when no host with role ``master`` (and a non-empty address) is found.
    """
    stat_url = http.url_join(host=sharpei, method='stat')
    with http_request(stat_url) as fd:
        statresp = json.load(fd)
        if str(shard_id) not in statresp:
            raise RuntimeError('shard_id=%r not found in /stat response: %r' % (shard_id, statresp))

        # Take the address of the first host reporting itself as master.
        addr = None
        for host in statresp[str(shard_id)]:
            if host['role'] == 'master':
                addr = host['address']
                break

        if not addr:
            raise RuntimeError('master host not found: %r' % statresp[str(shard_id)])
        return addr


def get_shard_id_by_host(sharpei, host):
    """Return the ``stat`` key of the first shard that lists ``host`` among its hosts.

    A shard matches when one of its entries has ``['address']['host'] == host``.
    The key is returned exactly as it appears in the decoded response.
    Returns ``None`` when no shard matches.
    """
    stat_url = http.url_join(host=sharpei, method='stat')
    with http_request(stat_url) as fd:
        statresp = json.load(fd)
        owning_shards = (
            shard_id
            for shard_id in statresp
            for shard_host in statresp[shard_id]
            if shard_host['address']['host'] == host
        )
        return next(owning_shards, None)
