#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Check if we are async replica
"""

import argparse
import binascii
import datetime
import logging
import logging.handlers
import os
import socket
import time
import traceback
from multiprocessing import TimeoutError
from multiprocessing.pool import ThreadPool

import psycopg2
import yt.wrapper as yt
from psycopg2.extras import RealDictCursor

from kazoo.client import KazooClient

try:
    from configparser import ConfigParser
except ImportError:
    from ConfigParser import ConfigParser


def get_config(path):
    """
    Load the INI-style configuration stored at ``path``.

    ``ConfigParser.read`` silently skips files it cannot open, so a missing
    file yields an empty parser rather than an error.

    :param path: path of the configuration file
    :return: populated ``ConfigParser`` instance
    """
    parsed = ConfigParser()
    parsed.read(path)
    return parsed


def get_logger(config):
    """
    Configure the ``main`` logger to write into a size-rotated file.

    The file path comes from ``[main] log_path``. The file rotates at 10 MiB
    and four backups are kept. The logger level is set to DEBUG.

    :return: tuple ``(logger, file_handler)``
    """
    logger = logging.getLogger('main')
    logger.setLevel(logging.DEBUG)

    max_bytes = 10 * 1024 * 1024
    file_handler = logging.handlers.RotatingFileHandler(
        config.get('main', 'log_path'), mode='a', maxBytes=max_bytes,
        backupCount=4, encoding=None, delay=0)
    formatter = logging.Formatter('%(asctime)s [%(levelname)s]: %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger, file_handler


def get_pg_info(log, config):
    """
    Get status of the local PostgreSQL host.

    Connects to the ``postgres`` database and returns a dict with:

    * ``replica`` -- True if ``pg_is_in_recovery()`` returned true;
    * ``alive`` -- True if the server answered and, for a replica, the lag
      computed from ``repl_mon.ts`` is at most ``[main] max_replication_lag``
      seconds.

    ``alive`` is only set once every check has passed. Any error, including an
    empty ``repl_mon`` or a NULL ``ts``, is logged and reported as
    ``alive = False``. The connection is always closed.
    """
    info = {
        'alive': False,
        'replica': False,
    }
    conn = None
    try:
        conn = psycopg2.connect('dbname=postgres')
        cur = conn.cursor()
        cur.execute('SELECT pg_is_in_recovery()')
        if cur.fetchone()[0]:
            info['replica'] = True
            cur.execute('SELECT extract(epoch FROM '
                        '(current_timestamp - ts)) FROM repl_mon')
            row = cur.fetchone()
            # Without a lag value we cannot tell how stale the replica is.
            if row is None or row[0] is None:
                raise ValueError('replication lag is unknown: '
                                 'repl_mon returned no value')
            max_lag = config.getint('main', 'max_replication_lag')
            if row[0] > max_lag:
                log.warning('Replication lag %s exceeds %s seconds',
                            row[0], max_lag)
                return info
        info['alive'] = True
    except Exception as exc:
        log.error('PostgreSQL is dead: %s', repr(exc))
    finally:
        if conn is not None:
            conn.close()

    return info


def get_pgsync_role(zk_client, hostname, config):
    """
    Work out the pgsync role of ``hostname`` from ZooKeeper lock contenders.

    :param zk_client: started ZooKeeper client providing ``Lock``
    :param hostname: host identifier to look for among the contenders
    :param config: parsed config; ``[main] pgsync_prefix`` is the lock root
    :return: ``'master'`` if the host is the first contender of
        ``<prefix>/leader``; otherwise ``'sync'`` if it is the first contender
        of ``<prefix>/sync_replica``; otherwise ``'async'``
    """
    prefix = config.get('main', 'pgsync_prefix')
    leader_queue = zk_client.Lock(os.path.join(prefix, 'leader')).contenders()
    sync_queue = zk_client.Lock(
        os.path.join(prefix, 'sync_replica')).contenders()

    def heads(queue):
        # True when hostname is at the front of a non-empty contender queue.
        return bool(queue) and queue[0] == hostname

    if heads(leader_queue):
        return 'master'
    if heads(sync_queue):
        return 'sync'
    return 'async'


def election(log, config, pg_info):
    """
    Participate in the ZooKeeper election that picks the host running the dump.

    Until ``[main] election_timeout`` seconds have passed, each host inspects
    the contenders of the ``[main] election_lock`` lock:

    * an async replica not yet contending, or any host when nobody contends,
      tries to acquire the lock (10 s timeout);
    * the pgsync master / sync replica release the lock when they are first
      and somebody else contends, so an async replica is preferred.

    After the deadline the first contender wins. It then waits until every
    other contender has left before returning.

    :param pg_info: result of ``get_pg_info`` (not used here)
    :return: True if this host won, False if it lost or an error occurred.
        A winner keeps its ZooKeeper session, and therefore the lock, open.
        A loser's session is stopped so its lock node disappears promptly.
    """
    zk_client = None
    won = False
    try:
        deadline = time.time() + config.getint('main', 'election_timeout')
        log.info('Starting election with deadline: %s', deadline)
        hostname = socket.getfqdn()
        zk_client = KazooClient(config.get('main', 'zk_hosts'))
        zk_client.start()
        election_lock = zk_client.Lock(
            config.get('main', 'election_lock'), hostname)
        role = get_pgsync_role(zk_client, hostname, config)
        log.info('Local host role is %s', role)
        while time.time() < deadline:
            contenders = election_lock.contenders()
            if not contenders or (role == 'async' and
                                  hostname not in contenders):
                log.info('Acquiring election lock')
                election_lock.acquire(timeout=10)
            elif (len(contenders) > 1 and contenders[0] == hostname and
                  role in ('master', 'sync')):
                log.info('We are %s: releasing election lock', role)
                election_lock.release()
            time.sleep(1)

        # Always re-read the contenders: the list from the last iteration may
        # predate our own acquire, and it does not exist if the loop never ran.
        while True:
            contenders = election_lock.contenders()
            if not contenders or contenders[0] != hostname:
                return False
            if len(contenders) == 1:
                won = True
                return True
            log.info('Waiting for other nodes to stop election: %s',
                     ', '.join(contenders[1:]))
            time.sleep(1)
    except Exception as exc:
        log.error('election error: %s', repr(exc))
        return False
    finally:
        # Losers drop their session so the winner stops waiting for them.
        if zk_client is not None and not won:
            try:
                zk_client.stop()
                zk_client.close()
            except Exception as exc:
                log.warning('Unable to stop ZooKeeper client: %s', repr(exc))


def to_json_compatible(record):
    """
    Convert values of a row that JSON cannot encode, in place.

    * ``datetime.datetime`` and ``datetime.date`` values become ISO-8601
      strings;
    * a non-NULL ``sha256_sum`` value (raw bytes) becomes a lowercase hex
      text string. NULL stays ``None``.

    :param record: mutable mapping of column name to value
    :return: the same ``record`` object
    """
    for key, value in list(record.items()):
        # datetime.datetime subclasses datetime.date, so one check covers both.
        if isinstance(value, datetime.date):
            record[key] = value.isoformat()
        elif key == 'sha256_sum' and value is not None:
            # hexlify returns bytes on Python 3; decode so JSON gets text.
            record[key] = binascii.hexlify(bytearray(value)).decode('ascii')
    return record


def _read_pk_offset(path):
    """
    Return the primary key stored in the offset file ``path``.

    Returns None when the file is missing, unreadable or empty, which means
    "start from the beginning of the table".
    """
    try:
        with open(path) as offset_file:
            value = offset_file.read().strip()
    except (IOError, OSError, ValueError):
        return None
    return value or None


def _write_pk_offset(path, pk):
    """Persist ``pk`` as text into the offset file ``path``."""
    with open(path, 'w') as offset_file:
        offset_file.write(str(pk))


def _fetch_batch(cur, table, pk_name, last_pk, limit):
    """
    Fetch up to ``limit`` rows of ``table`` ordered by ``pk_name``.

    Only rows with a key greater than ``last_pk`` are returned; all rows from
    the start are returned when ``last_pk`` is None. Table and column names
    come from the config and are interpolated as-is. The key value and limit
    are passed as query parameters, so quotes inside a key cannot break the
    statement.
    """
    if last_pk is None:
        cur.execute(
            'SELECT * FROM {0} ORDER BY {1} LIMIT %s'.format(table, pk_name),
            (limit,))
    else:
        # str() yields a quoted literal, as the original '%s' formatting did;
        # PostgreSQL casts it to the key column's type.
        cur.execute(
            'SELECT * FROM {0} WHERE {1} > %s ORDER BY {1} LIMIT %s'.format(
                table, pk_name),
            (str(last_pk), limit))
    return cur.fetchall()


def _wait_for_write(log, result, timeout=60):
    """
    Block until the asynchronous YT write ``result`` finishes.

    Returns True on success and False if the write raised. A write that is
    only slow is waited for again rather than treated as failed, so the same
    page is never sent twice.
    """
    while True:
        try:
            result.get(timeout=timeout)
        except TimeoutError:
            log.warning('yt write is still running after %s seconds', timeout)
        except Exception as exc:
            log.error(exc.__class__)
            log.error(exc)
            return False
        else:
            return True


def do_dump_table(log, dbname, table_prefix, table):
    """
    Copy one PostgreSQL table into the YT table ``<table_prefix>/<table>``.

    ``table`` has the form ``<table>:<pk column>``. Rows are read in pages of
    200000 ordered by the key. Each page is written to YT asynchronously while
    the next page is fetched.

    The key of the last page whose YT write succeeded is stored in
    ``/tmp/<table>.pk_offset``, and reading resumes after that key. A failed
    write makes the page be fetched and sent again. A failed query is rolled
    back and retried after 5 s.
    """
    table, pk_name = [part.strip() for part in table.split(':')]
    table_name = table_prefix + '/' + table
    if not yt.exists(table_name):
        yt.create("table", table_name)
        log.debug('create new yt table for %s', table)

    limit = 200000
    offset_file_name = '/tmp/%s.pk_offset' % table
    # NOTE(review): the offset file is not keyed by date while table_prefix
    # is, so a later run resumes after the key stored by an earlier one.
    last_pk = _read_pk_offset(offset_file_name)

    conn = psycopg2.connect('dbname=%s port=5432' % dbname)
    log.debug('success on connect to postgresql')
    pool = None
    try:
        cur = conn.cursor(cursor_factory=RealDictCursor)
        pool = ThreadPool(processes=1)
        # In-flight write: (AsyncResult, key the page starts after, last key).
        pending = None
        while True:
            try:
                batch_rows = _fetch_batch(cur, table, pk_name, last_pk, limit)
            except Exception as e:
                log.error(e.__class__)
                log.error(e)
                traceback.print_exc()
                # A failed statement aborts the transaction; without a
                # rollback every retry would fail the same way.
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    log.error('rollback failed: %s', repr(rollback_error))
                time.sleep(5)
                continue
            log.debug('success on fetch %s after %s', table, last_pk)
            data = [to_json_compatible(row) for row in batch_rows]

            if pending is not None:
                result, start_pk, end_pk = pending
                pending = None
                if not _wait_for_write(log, result):
                    # The previous page did not reach YT: fetch and resend it.
                    last_pk = start_pk
                    time.sleep(5)
                    continue
                # Record progress only after the write is confirmed.
                _write_pk_offset(offset_file_name, end_pk)

            if not data:
                break

            log.debug('ready to send %d rows', len(data))
            result = pool.apply_async(
                yt.write_table,
                (yt.TablePath(table_name, append=True), data),
                dict(
                    format=yt.JsonFormat(attributes={"encode_utf8": False}),
                    raw=False,
                ),
            )
            log.debug('send async yt write %s after %s', table, last_pk)
            end_pk = batch_rows[-1][pk_name]
            pending = (result, last_pk, end_pk)
            last_pk = end_pk
    finally:
        if pool is not None:
            pool.close()
            pool.join()
        conn.close()

def do_dump(log, config):
    """
    Dump every table listed in ``[main] tables`` to YT.

    ``tables`` is a comma-separated list of ``<table>:<pk column>`` entries.
    Whitespace around entries and empty entries (e.g. from a trailing comma)
    are ignored. Tables go under ``<table_prefix><YYYY-MM-DD>/<rs_name>``,
    which is created if missing. The YT proxy is ``[main] yt_proxy`` when set,
    otherwise ``hahn.yt.yandex.net``.
    """
    dbname = config.get('main', 'dbname').strip()
    tables = [entry.strip()
              for entry in config.get('main', 'tables').split(',')
              if entry.strip()]
    table_prefix = (
        config.get('main', 'table_prefix') +
        time.strftime("%Y-%m-%d") +
        '/' +
        config.get('main', 'rs_name')
    )
    if config.has_option('main', 'yt_proxy'):
        yt_proxy = config.get('main', 'yt_proxy').strip()
    else:
        yt_proxy = 'hahn.yt.yandex.net'
    yt.update_config({'proxy': {'url': yt_proxy},
                      'token': config.get('main', 'yt_token')})
    if not yt.exists(table_prefix):
        yt.create("map_node", table_prefix, recursive=True)
    log.info('Starting dump')
    for table in tables:
        do_dump_table(log, dbname, table_prefix, table)


def main():
    """
    Console entry point.

    Reads the config given by ``-c/--config`` (default ``/etc/dump_yt.conf``)
    and sets up logging. The election is skipped when ``get_pg_info`` reports
    the host as not alive. The dump runs only if the election is won.
    Unexpected errors are logged with a traceback and re-raised.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-c', '--config', type=str,
                            default='/etc/dump_yt.conf', help='config path')
    options = arg_parser.parse_args()
    config = get_config(options.config)
    log, _ = get_logger(config)

    try:
        pg_info = get_pg_info(log, config)
        if not pg_info['alive']:
            log.info('Not participating in election')
        elif not election(log, config, pg_info):
            log.info('Election lost. Skipping dump')
        else:
            do_dump(log, config)
    except Exception as exc:
        log.exception('Unable to make dump: %s' % exc)
        raise


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
