import requests
import os
import random
import re
import logging
import time
import crcmod.predefined

from multiprocessing.pool import ThreadPool

import db_requests

# CRC-16 checksum function built from crcmod's predefined 'crc-16' algorithm.
# _gid_from_uid uses it to map non-numeric uids onto a gid.
calculate_crc16 = crcmod.predefined.mkCrcFun('crc-16')

def hub_from_env():
    """Return the base URL of the xiva hub for the current Qloud environment.

    When QLOUD_ENVIRONMENT is not set, a hub on localhost is assumed.

    Raises:
        Exception: if QLOUD_ENVIRONMENT holds a value not listed below.
    """
    hub_by_env = {
        'production': 'http://xivahub.mail.yandex.net',
        'corp': 'http://xivahubcorp.mail.yandex.net',
        'sandbox': 'http://xivahub-sandbox.mail.yandex.net',
    }
    env_type = os.environ.get('QLOUD_ENVIRONMENT')
    if env_type is None:
        # Assume local environment.
        return 'http://localhost:14080'
    if env_type in hub_by_env:
        return hub_by_env[env_type]
    raise Exception('unknown QLOUD_ENVIRONMENT: ' + env_type)

def get(dbname, timeout=30):
    """Fetch the full shard list for database ``dbname`` from the xiva hub.

    Args:
        dbname: database whose shards are requested; it is interpolated into
            the ``/<dbname>/shards_full`` hub path.
        timeout: seconds to wait for the hub, passed to ``requests.get``.
            ``None`` waits indefinitely.

    Returns:
        The ``'shards'`` member of the hub's JSON response.

    Raises:
        requests.HTTPError: if the hub answers with an error status.
        requests.Timeout: if the hub does not answer within ``timeout``.
    """
    url = hub_from_env() + '/{}/shards_full'.format(dbname)
    # Without a timeout, requests can block forever on an unresponsive hub.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.json()['shards']

def get_unique(dbname):
    """Fetch the shards of ``dbname`` with entries sharing a master conninfo collapsed."""
    all_shards = get(dbname)
    return _deduplicate(all_shards)

def is_master(conninfo):
    """Tell whether the server behind ``conninfo`` accepts writes.

    Runs ``SHOW transaction_read_only`` with ``connect_timeout=5`` appended
    to the conninfo. A reported value of ``'off'`` means "master".
    """
    probe_conninfo = conninfo + ' connect_timeout=5'
    rows = db_requests.execute(probe_conninfo, 'SHOW transaction_read_only')
    first_row = rows[0]
    return first_row[0] == 'off'

def replace_hosts(conninfo, hosts_str):
    """Return ``conninfo`` with its ``host=`` value replaced by ``hosts_str``.

    Args:
        conninfo: ``key=value`` connection string.
        hosts_str: new ``host=`` value, e.g. a comma-separated host list.
            It is inserted literally.

    Every ``host=`` whose value consists of letters, digits, dots, dashes
    and commas is rewritten. A conninfo without one is returned unchanged.
    """
    # Raw string: '\.' and '\-' in a plain literal are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning since Python 3.12).
    # A function replacement inserts hosts_str verbatim instead of parsing
    # it as a template where backslashes mean group references.
    return re.sub(r'host=[a-zA-Z0-9\.\-,]+',
                  lambda _match: 'host=' + hosts_str,
                  conninfo)

def remove_master_host(conninfo):
    """Return ``conninfo`` with the hosts that currently are masters removed.

    Each host from the ``host=`` list is probed on its own via ``is_master``.
    Hosts that cannot be probed are kept. Surviving hosts keep their
    original order, so the result is deterministic.

    If every host turns out to be a master, ``conninfo`` is returned
    unchanged rather than with an empty ``host=`` value.
    """
    # dict.fromkeys removes duplicate hosts while preserving their order.
    candidate_hosts = list(dict.fromkeys(hosts(conninfo, False)))
    kept_hosts = []
    for host in candidate_hosts:
        host_conninfo = replace_hosts(conninfo, host)
        try:
            master = is_master(host_conninfo)
        except Exception as e:
            # Do not remove a host from conninfo if it is inaccessible.
            # Otherwise we risk removing all of them because of a network glitch.
            logging.warning('cannot check whether %s is a master, keeping it: %s',
                            host, e)
            master = False
        if not master:
            kept_hosts.append(host)
    if not kept_hosts:
        # Only masters are listed. An empty host= would not point at any of
        # this shard's hosts, so hand back the conninfo untouched.
        return conninfo
    return replace_hosts(conninfo, ','.join(kept_hosts))

def conninfo_for_read(shard):
    """Pick a conninfo for read-only queries on ``shard``.

    If the shard has replicas, one replica conninfo is chosen at random and
    its master hosts are stripped. Otherwise the master conninfo is used.
    """
    replicas = shard['replicas']
    if len(replicas) > 0:
        return remove_master_host(random.choice(replicas))
    return shard['master']

def conninfo_for_write(shard):
    """Return the master conninfo of ``shard``, used for writes."""
    master_conninfo = shard['master']
    return master_conninfo

def admin_conninfo(shard):
    """Return the master conninfo of ``shard`` with ``user=`` switched to xiva_admin."""
    write_conninfo = conninfo_for_write(shard)
    admin = re.sub('user=[a-zA-Z0-9_]+', 'user=xiva_admin', write_conninfo)
    return admin

def gids_range(shard):
    """Return ``[start_gid, end_gid]``, the inclusive gid bounds of ``shard``."""
    return [shard[key] for key in ('start_gid', 'end_gid')]

def hosts(conninfo, pretty=True):
    """Return the hosts listed in the ``host=`` parameter of ``conninfo``.

    Args:
        conninfo: ``key=value`` connection string containing ``host=``.
        pretty: when true, shorten each FQDN to the part before its first dot
            (``db1.example.net`` -> ``db1``). Names without a dot, such as
            ``localhost``, are returned unchanged.

    Returns:
        List of host names in the order they appear in ``conninfo``.

    Raises:
        ValueError: if ``conninfo`` has no ``host=`` parameter.
    """
    # The value runs up to the next space, or to the end of the string.
    m = re.search(r'host=(.+?) ', conninfo) or re.search(r'host=(.+?)$', conninfo)
    if m is None:
        # Do not echo conninfo: it may carry credentials.
        raise ValueError('no host= parameter in conninfo')
    fqdns = m.group(1).split(',')
    if not pretty:
        return fqdns
    short_names = []
    for fqdn in fqdns:
        label = re.match(r'(.+?)\.', fqdn)
        # re.match yields None for names without a dot; keep those whole.
        short_names.append(label.group(1) if label else fqdn)
    return short_names

def dbname(conninfo):
    """Return the value of the ``dbname=`` parameter of ``conninfo``.

    Raises:
        ValueError: if ``conninfo`` has no ``dbname=`` whose value consists
            of letters, digits and underscores.
    """
    m = re.search(r'dbname=([a-zA-Z0-9_]+)', conninfo)
    if m is None:
        # Fail with a clear message instead of an AttributeError on None.group.
        raise ValueError('no dbname= parameter in conninfo')
    return m.group(1)

def friendly_name(shard):
    """Return a human-readable shard name: the dbname of its master conninfo."""
    master_conninfo = conninfo_for_write(shard)
    return dbname(master_conninfo)

def find(uid, shards):
    """Return the first shard in ``shards`` whose gid range covers ``uid``.

    ``uid`` is mapped to a gid with ``_gid_from_uid``. Both range bounds
    are inclusive.

    Raises:
        StopIteration: if no shard covers the gid.
    """
    gid = _gid_from_uid(uid)
    covering = (shard for shard in shards
                if shard['start_gid'] <= gid <= shard['end_gid'])
    return next(covering)

def _gid_from_uid(uid):
    """Map ``uid`` onto a gid.

    Numeric uids (anything ``int()`` accepts) are reduced modulo 65536.
    Other uids are CRC-16 checksummed after ``str.encode()``.

    Raises:
        ValueError: if ``uid`` is numeric but negative.
    """
    try:
        gid = int(uid)
    except ValueError:
        return calculate_crc16(uid.encode())
    # Checked outside the try so this error is not mistaken for a
    # non-numeric uid. str.format matches the '{}' placeholders; the old
    # '%' formatting raised TypeError instead of this message.
    if gid < 0:
        raise ValueError('negative gid {} for uid {}'.format(gid, uid))
    return gid % 65536

# Registry of every Pool created via _register_pool. A Pool's position in
# this list is its `id`, which travels through worker results so _callback
# can look the Pool up again with _get_pool. The code shown here only ever
# appends, so stopped pools stay referenced.
pools = []

class LoggerShardNameAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with ``extra['shard_name']``."""

    def process(self, msg, kwargs):
        """Prefix ``msg`` with the shard name and a space; pass ``kwargs`` through."""
        prefix = str(self.extra['shard_name'])
        return ' '.join((prefix, str(msg))), kwargs

def _deduplicate(shards_list):
    """Condense all shards having the same master conninfo into one.

    When several shards share a conninfo, the last of them is kept. The
    result follows the order in which each conninfo first appeared.

    Returns:
        A list of shards. A bare ``dict.values()`` view cannot be indexed,
        passed to ``random.choice`` or pickled.
    """
    shards_by_conninfo = {conninfo_for_write(s): s for s in shards_list}
    return list(shards_by_conninfo.values())

def _register_pool(p):
    """Append ``p`` to the module-level ``pools`` registry and return its index."""
    pools.append(p)
    last_index = len(pools) - 1
    return last_index

def _get_pool(id):
    """Return the Pool registered under index ``id`` by ``_register_pool``."""
    registered = pools[id]
    return registered

def _call_operation(*args):
    """Run one invocation of a shard's operation inside a pool worker.

    ``args`` is ``(shard, operation, retry_interval, pool_id)``. The
    operation is called as ``operation(shard, logger)``, where ``logger``
    prefixes messages with the shard name. If it raises an ``Exception``, the
    error is logged and the worker sleeps ``retry_interval`` seconds.

    Returns ``(shard, pool_id)``. ``Pool.start_shard`` registers
    ``_callback`` to receive this result.
    """
    shard, operation, retry_interval, pool_id = args
    try:
        root_logger = logging.getLogger()
        shard_name = friendly_name(shard)
        shard_logger = LoggerShardNameAdapter(root_logger, {'shard_name': shard_name})
        operation(shard, shard_logger)
    except Exception as err:
        logging.exception('exception while processing shard %s, %s: %s',
                          friendly_name(shard), type(err).__name__, str(err))
        # Back off before the result is handed back for re-scheduling.
        time.sleep(retry_interval)
    return shard, pool_id

def _callback(args):
    """ThreadPool completion callback: enqueue the finished shard again on its Pool."""
    shard, pool_id = args
    owning_pool = _get_pool(pool_id)
    owning_pool.start_shard(shard)

class Pool:
    """Repeatedly runs ``operation`` on every shard, with one worker thread per shard.

    Each shard is submitted to a ThreadPool. When its run returns (after a
    ``retry_interval`` sleep if the operation raised), ``_callback`` submits
    the shard again. Each shard thus stays scheduled until the pool is
    stopped.
    """

    def __init__(self, shards, operation, retry_interval):
        """Create the worker pool and register it in the module registry.

        Args:
            shards: shard dicts; anything ``len()`` and iteration accept.
            operation: callable invoked as ``operation(shard, logger)``.
            retry_interval: seconds to sleep after ``operation`` raises.
        """
        self.shards = shards
        self.operation = operation
        # ThreadPool rejects processes=0, so an empty shard list still gets
        # one idle worker instead of raising ValueError.
        self.pool = ThreadPool(processes=max(1, len(shards)))
        self.retry_interval = retry_interval
        # Registry index. It is passed through each worker call so
        # _callback can find this Pool again.
        self.id = _register_pool(self)

    def start(self):
        """Submit every shard to the worker pool."""
        for shard in self.shards:
            self.start_shard(shard)

    def start_shard(self, shard):
        """Queue one run of ``operation`` for ``shard``; ``_callback`` re-queues it when done."""
        logging.debug('enqueue shard %s', friendly_name(shard))
        self.pool.apply_async(_call_operation,
            (shard, self.operation, self.retry_interval, self.id),
            callback=_callback)

    def stop(self):
        """Terminate the underlying ThreadPool and join its workers."""
        self.pool.terminate()
        self.pool.join()
