#!/usr/bin/env python2.7
# -*- coding: utf-8 -*-

import logging
import random
import requests
import re
from psycopg2 import connect as pg
from psycopg2 import Error as PgError

from collections import defaultdict, namedtuple


class PgClient(object):
    """Run queries against a sharded PostgreSQL cluster discovered by host group.

    Hosts are fetched from the ``groups2hosts`` HTTP API and grouped by the
    shard number embedded in each host name. A query is executed on one host
    per shard; the shard's hosts are tried in (shuffled) order until one of
    them yields a non-None result.
    """

    # Connection settings shared by every host of the group.
    _PgConf = namedtuple('PgConf', 'user password db port')

    # Host names look like ``<letters>NN<letter>.<...>yandex.net`` where NN is
    # the shard number. Raw string so the backslashes reach ``re`` untouched;
    # the dot in ``yandex.net`` is escaped so it only matches a literal dot.
    _HOST_RE = re.compile(r'[a-z_\-]+?(\d+)[a-z]\..*?yandex\.net')

    # Seconds to wait for the host-list API; ``requests`` never times out
    # unless a timeout is given explicitly.
    _HOSTS_API_TIMEOUT = 10

    def __init__(self, group, user, password, db, port):
        """Discover the hosts of *group* and remember connection settings.

        Args:
            group: host group name passed to the ``groups2hosts`` API.
            user, password, db, port: PostgreSQL connection parameters used
                for every host.
        """
        self.pg_group = group
        self.get_pg_hosts()
        self.pg = self._PgConf(user=user, password=password, db=db, port=port)

    @classmethod
    def _parse_pg_hosts(cls, text):
        """Group whitespace-separated host names in *text* by shard number.

        Names that do not match ``_HOST_RE`` are ignored.

        Returns:
            ``defaultdict(list)`` mapping shard number (int) to the matching
            host names, in the order they appear in *text*.
        """
        pg_hosts = defaultdict(list)
        for host in text.split():
            host_match = cls._HOST_RE.search(host)
            if host_match:
                pg_hosts[int(host_match.group(1))].append(host)
        return pg_hosts

    def get_pg_hosts(self):
        """Fetch the hosts of ``self.pg_group`` and store them in ``self.pg_hosts``.

        Each shard's host list is shuffled, so hosts are tried in random order.

        Raises:
            requests.HTTPError: if the API answers with an error status.
            requests.RequestException: on network failure or timeout.
        """
        response = requests.get("https://c.yandex-team.ru/api/groups2hosts/%s" % self.pg_group,
                                timeout=self._HOSTS_API_TIMEOUT)
        # Without this check an error page would silently parse into zero hosts.
        response.raise_for_status()
        pg_hosts = self._parse_pg_hosts(response.text)
        for hosts in pg_hosts.values():
            random.shuffle(hosts)
        self.pg_hosts = pg_hosts

    def execute_host(self, host, query, result_aggregation, replica_only=False):
        """Run *query* on a single *host* and aggregate the result.

        Args:
            host: host name to connect to.
            query: SQL string, or a tuple/list whose items are unpacked into
                ``cursor.execute`` (e.g. ``(sql, params)``).
            result_aggregation: callable receiving the executed cursor; its
                return value is returned.
            replica_only: if true, skip the host when ``pg_is_in_recovery()``
                reports it is not in recovery.

        Returns:
            ``result_aggregation(cursor)``, or None when the host was skipped
            or a ``psycopg2.Error`` occurred (the error is logged as a warning).
        """
        p_conf = self.pg
        conn = None
        try:
            conn = pg(host=host, port=p_conf.port, user=p_conf.user,
                      password=p_conf.password, dbname=p_conf.db, connect_timeout=3)
            # ``with conn`` only commits / rolls back the transaction; it does
            # NOT close the connection, hence the explicit close in ``finally``.
            with conn:
                if replica_only:
                    with conn.cursor() as cur:
                        cur.execute("SELECT pg_is_in_recovery();")
                        is_replica = cur.fetchone()
                    if is_replica and not is_replica[0]:
                        return None

                with conn.cursor() as cur:
                    if isinstance(query, (tuple, list)):
                        cur.execute(*query)
                    else:
                        cur.execute(query)
                    return result_aggregation(cur)
        except PgError as e:
            logging.warning('%s: %s', host, e)
        finally:
            if conn is not None:
                conn.close()
        return None

    def execute_shard(self, shard, query, result_aggregation, replica_only=False):
        """Run *query* on the first host of *shard* that returns a non-None result.

        Args:
            shard: shard number (a key of ``self.pg_hosts``).
            query, result_aggregation, replica_only: forwarded to ``execute_host``.

        Raises:
            EnvironmentError: if no host of the shard (or an unknown shard)
                produced a result.
        """
        # ``.get`` so an unknown shard is not inserted into the defaultdict,
        # which would make every later ``execute`` fail on an empty shard.
        for host in self.pg_hosts.get(shard, ()):
            result = self.execute_host(host, query, result_aggregation, replica_only=replica_only)
            if result is not None:
                return result
        message = 'Failed to get reply from any of PG host'
        logging.error('%s (shard %s)', message, shard)
        raise EnvironmentError(message)

    def execute(self, query, result_aggregation, shards_aggregation, keep_shard=False, replica_only=False):
        """Run *query* on every shard and combine the per-shard results.

        Args:
            query, result_aggregation, replica_only: forwarded to ``execute_shard``.
            shards_aggregation: callable combining the per-shard results.
            keep_shard: if true, pass ``{shard: result}`` to
                *shards_aggregation*; otherwise a list of results ordered by
                ascending shard number.

        Returns:
            Whatever *shards_aggregation* returns.

        Raises:
            EnvironmentError: if any shard produced no result.
        """
        # Sorted so the list form has a deterministic, shard-ordered layout.
        shards = sorted(self.pg_hosts)
        if keep_shard:
            result_by_shard = {shard: self.execute_shard(shard, query, result_aggregation,
                                                         replica_only=replica_only)
                               for shard in shards}
        else:
            result_by_shard = [self.execute_shard(shard, query, result_aggregation,
                                                  replica_only=replica_only)
                               for shard in shards]
        return shards_aggregation(result_by_shard)
