# -*- coding: utf-8 -*-
"""
Скрипт подготовки локальных PG

1. Скачивает миграции с https://a.yandex-team.ru/arc/trunk/arcadia/disk/admin/salt/pg/salt/components/pg-code/diskdb и
                https://a.yandex-team.ru/arc/trunk/arcadia/disk/admin/salt/pg/salt/components/pg-code/disk_commondb
2. Стопает запущенные PG по указанным портам
3. Удаляет директорию с данными PG
4. Создает новые базы
5. Стартует PG
6. Создает нужных пользователей
7. Накатывает миграции из п.1 с помощью pgmigrate
"""
import argparse
import os
import subprocess
import shutil
import pwd
import grp
import logging
import psycopg2


# OS account and group that get ownership of every instance data directory.
PG_USER = 'postgres'
PG_GROUP = PG_USER
# Parent directory; each instance lives in a sub-directory named after its port.
PG_DATA_DIR = '/data/pg'

# Roles and databases created on every fresh instance.
# create_mpfs_users() splits this text on newlines and runs each non-empty
# line as a separate psql command, so keep exactly one statement per line.
SQL_MPFS_USERS = (
    u"\n"
    u"CREATE USER disk_mpfs WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER disk_mpfs_read WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER disk_mworker WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER disk_mworker_read WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER disk_pworker WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER disk_pworker_read WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER djfs WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER djfs_worker WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER djfs_albums WITH SUPERUSER PASSWORD 'diskpasswd';\n"
    u"CREATE USER monitor WITH SUPERUSER PASSWORD 'monitorpasswd';\n"
    u"CREATE DATABASE diskdb01 OWNER=disk_mpfs;\n"
    u"CREATE DATABASE disk_commondb01 OWNER=disk_mpfs;\n"
)

# Full contents written to pg_hba.conf of every instance: password-less
# ("trust") access for local sockets and for any IPv4 / IPv6 address.
PG_HBA = (
    '\n'
    'local all all trust\n'
    'host all all 0.0.0.0/0 trust\n'
    'host all all ::0/0 trust\n'
)


def create_mpfs_users(port):
    """Create the roles and databases listed in ``SQL_MPFS_USERS``.

    Every non-empty line of ``SQL_MPFS_USERS`` is executed as its own
    ``psql -c`` call (run through sudo as the ``postgres`` user) against the
    ``postgres`` database of the server on ``port``.

    :param port: integer port of the target server.
    :raises subprocess.CalledProcessError: if a psql call exits non-zero.
    """
    statements = [stmt for stmt in SQL_MPFS_USERS.split('\n') if stmt]
    for stmt in statements:
        psql_cmd = 'sudo -u postgres psql postgres -p %i -c "%s"' % (port, stmt)
        _log_cmd(psql_cmd)
        subprocess.check_call(psql_cmd, shell=True)


def prepare_pg_data_dir(data_dir, clean=True):
    """Make sure ``data_dir`` exists and is owned by ``PG_USER``:``PG_GROUP``.

    :param data_dir: path of the directory to prepare.
    :param clean: when true, an existing ``data_dir`` is removed first so the
        result is an empty directory; when false, an existing directory and
        its contents are kept and only the ownership of the directory itself
        is (re)applied.
    :raises KeyError: if ``PG_USER`` or ``PG_GROUP`` is not a known account
        or group.
    :raises OSError: if the directory cannot be removed, created or chown'ed.
    """
    if clean and os.path.exists(data_dir):
        shutil.rmtree(data_dir)

    # os.makedirs() raises when the target already exists, which used to make
    # clean=False unusable for an existing directory.
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)

    uid = pwd.getpwnam(PG_USER).pw_uid
    gid = grp.getgrnam(PG_GROUP).gr_gid
    os.chown(data_dir, uid, gid)


class PGCTL(object):
    """Wrapper around the ``pg_ctl`` binary for one data directory / port pair.

    Each command is built from one of the class-level templates using the
    values stored in ``self.db_data``, logged, and executed through the shell.
    """

    # pg_ctl of PostgreSQL 9.6, invoked through sudo as the postgres user.
    PG_CTL = 'sudo -u postgres /usr/lib/postgresql/9.6/bin/pg_ctl'
    # Command templates; placeholders are filled from ``self.db_data``.
    PG_INIT = '%(pg_ctl)s init -o "--encoding=UTF8 --locale=en_US.UTF-8 --lc-collate=C --lc-ctype=C" -D %(pg_data)s'
    PG_START = '%(pg_ctl)s start -o "-i -p %(port)s -N 100 -F -c logging_collector=off" -D %(pg_data)s -w'
    PG_STOP = '%(pg_ctl)s stop -m fast -D %(pg_data)s'
    PG_STATUS = '%(pg_ctl)s status -D %(pg_data)s'

    def __init__(self, pg_data, port):
        """Remember the data directory and the port (coerced to ``int``).

        :param pg_data: path of the instance data directory.
        :param port: port the server is started on; anything ``int()`` accepts.
        """
        self.db_data = dict(pg_ctl=self.PG_CTL, pg_data=pg_data, port=int(port))

    def _command(self, template):
        """Fill ``template`` with this instance's settings, log and return it."""
        rendered = template % self.db_data
        _log_cmd(rendered)
        return rendered

    def is_alive(self):
        """Return True when the ``pg_ctl status`` command exits with code 0."""
        exit_code = subprocess.call(self._command(self.PG_STATUS), shell=True)
        return exit_code == 0

    def init_db(self):
        """Run the init command, then replace the generated ``pg_hba.conf``.

        :raises subprocess.CalledProcessError: if the init command fails.
        """
        subprocess.check_call(self._command(self.PG_INIT), shell=True)
        self._edit_pg_hba()

    def start_server(self):
        """Run the start command.

        :raises subprocess.CalledProcessError: if the command exits non-zero.
        """
        subprocess.check_call(self._command(self.PG_START), shell=True)

    def stop_server(self):
        """Run the stop command.

        :raises subprocess.CalledProcessError: if the command exits non-zero.
        """
        subprocess.check_call(self._command(self.PG_STOP), shell=True)

    def _edit_pg_hba(self):
        """Overwrite ``pg_hba.conf`` in the data directory with ``PG_HBA``."""
        hba_file = os.path.join(self.db_data['pg_data'], 'pg_hba.conf')
        with open(hba_file, 'w') as hba:
            hba.write(PG_HBA)


class MigrationHelper(object):
    """Base helper that checks out and applies pgmigrate migrations for one database.

    Concrete subclasses are expected to override:

    * ``DB`` - name of the migrations directory under ``pg-code`` in the
      repository;
    * ``DB_NAME`` - name of the database the helper connects to;
    * ``MIGRATIONS_DIR`` - local directory the migrations are checked out into.

    Creating an instance opens a psycopg2 connection (``self.conn``) and a
    cursor (``self.cursor``); call ``close()`` once the helper is no longer
    needed.
    """
    DB_CONN_STRING_TMPL = "host='localhost' port=%(port)s dbname='%(dbname)s' user='disk_mpfs' password='diskpasswd'"
    # Declared (empty) on the base class like DB_NAME and MIGRATIONS_DIR, so
    # every attribute read by get_migrations() is visible here.
    DB = ''
    DB_NAME = ''
    RUN_MIGRATIONS_TMPL = 'pgmigrate -d %(base_dir)s -t %(target)s -c "%(conn_str)s" -a afterAll:code,afterAll:grants migrate'
    MIGRATIONS_DIR = ''
    MIGRATION_GET_CMD = (
        'svn co '
        'svn+ssh://%(user)s@arcadia.yandex.ru/arc/trunk/arcadia/disk/admin/salt/pg/salt/components/pg-code/%(db)s '
        '%(dir)s'
    )

    def __init__(self, port):
        """Connect to ``DB_NAME`` on localhost at ``port`` as ``disk_mpfs``.

        Errors raised by ``psycopg2.connect`` are not handled here and
        propagate to the caller.

        :param port: port of the PostgreSQL server to connect to.
        """
        self._conn_string = self.DB_CONN_STRING_TMPL % {'port': port, 'dbname': self.DB_NAME}
        self.conn = psycopg2.connect(self._conn_string)
        self.cursor = self.conn.cursor()

    @classmethod
    def get_migrations(cls):
        """Check out a fresh copy of the migrations into ``MIGRATIONS_DIR``.

        An existing checkout directory is removed first. The svn user comes
        from the ``SSH_USER`` environment variable and falls back to
        ``robot-disk-cloud``.

        :raises NotImplementedError: if the class does not define ``DB`` and
            ``MIGRATIONS_DIR`` (e.g. when called on the base class).
        :raises subprocess.CalledProcessError: if the checkout command fails.
        """
        if not cls.DB or not cls.MIGRATIONS_DIR:
            raise NotImplementedError(
                '%s must define DB and MIGRATIONS_DIR' % cls.__name__
            )
        if os.path.exists(cls.MIGRATIONS_DIR):
            shutil.rmtree(cls.MIGRATIONS_DIR)
        migration_get_cmd = cls.MIGRATION_GET_CMD % dict(
            user=os.getenv('SSH_USER', 'robot-disk-cloud'),
            db=cls.DB,
            dir=cls.MIGRATIONS_DIR,
        )
        _log_cmd(migration_get_cmd)
        subprocess.check_call(migration_get_cmd, shell=True)

    def check(self):
        """Execute ``SELECT 1;`` on the cursor as a connectivity smoke test."""
        self.cursor.execute('SELECT 1;')

    def run_migration(self, target_version='latest'):
        """Apply the checked-out migrations with pgmigrate.

        :param target_version: value passed to pgmigrate's ``-t`` option.
        :raises RuntimeError: if ``MIGRATIONS_DIR`` does not exist.
        :raises subprocess.CalledProcessError: if pgmigrate exits non-zero.
        """
        # An explicit check instead of ``assert`` so it also runs under -O.
        if not os.path.exists(self.MIGRATIONS_DIR):
            raise RuntimeError(
                'Migrations directory %r does not exist; run get_migrations() first'
                % (self.MIGRATIONS_DIR,)
            )
        cmd = self.RUN_MIGRATIONS_TMPL % {
            'base_dir': self.MIGRATIONS_DIR,
            'conn_str': self._conn_string,
            'target': target_version,
        }
        _log_cmd(cmd)
        subprocess.check_call(cmd, shell=True)

    def close(self):
        """Close the cursor and the connection opened by the constructor."""
        try:
            self.cursor.close()
        finally:
            self.conn.close()


class DiskDBMigrationHelper(MigrationHelper):
    """Migration helper for the ``diskdb`` migrations and the ``diskdb01`` database."""
    DB = 'diskdb'
    DB_NAME = '%s01' % DB
    # Local checkout directory, placed directly under the filesystem root.
    MIGRATIONS_DIR = '/%s_clone' % DB


class DiskCommonDBMigrationHelper(MigrationHelper):
    """Migration helper for the ``disk_commondb`` migrations and the ``disk_commondb01`` database."""
    DB = 'disk_commondb'
    DB_NAME = '%s01' % DB
    # Local checkout directory, placed directly under the filesystem root.
    MIGRATIONS_DIR = '/%s_clone' % DB


def main(start_port=12000, instances_num=2):
    """Recreate local PostgreSQL instances on consecutive ports and migrate them.

    The migrations are checked out once. Then, for every port in
    ``[start_port, start_port + instances_num)``: a running server for that
    data directory is stopped, the data directory ``PG_DATA_DIR/<port>`` is
    wiped and re-initialised, the server is started, the users and databases
    are created and the migrations are applied.

    :param start_port: port of the first instance.
    :param instances_num: number of instances to create; must be positive.
    :raises ValueError: if ``instances_num`` is not positive.
    :raises RuntimeError: if a server is not reported as running after start.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under -O,
    # and this script deletes data directories.
    if not instances_num > 0:
        raise ValueError('instances_num must be positive, got %r' % (instances_num,))

    migration_helpers = (DiskDBMigrationHelper, DiskCommonDBMigrationHelper)

    for helper_cls in migration_helpers:
        helper_cls.get_migrations()

    for port in range(start_port, start_port + instances_num):
        # Each instance keeps its data in a directory named after its port.
        data_dir = os.path.join(PG_DATA_DIR, str(port))
        pg_ctl = PGCTL(data_dir, port)
        if pg_ctl.is_alive():
            pg_ctl.stop_server()
        prepare_pg_data_dir(data_dir, clean=True)
        pg_ctl.init_db()
        pg_ctl.start_server()
        if not pg_ctl.is_alive():
            raise RuntimeError('PostgreSQL on port %s is not running after start' % port)

        create_mpfs_users(port)

        for helper_cls in migration_helpers:
            helper = helper_cls(port)
            try:
                helper.check()
                helper.run_migration()
            finally:
                # The helper opens a connection in its constructor; release
                # it even when the check or the migration fails.
                helper.conn.close()


def _log_cmd(cmd):
    """Log, at INFO level on the root logger, the command about to be run."""
    logging.log(logging.INFO, 'Exec: > %s', cmd)


if __name__ == '__main__':
    # Both positional arguments are mandatory integers: the first port and
    # the number of instances to create.
    cli = argparse.ArgumentParser()
    for positional in ('start_port', 'instances_num'):
        cli.add_argument(positional, type=int)
    options = cli.parse_args()

    # Configure logging so the INFO-level command log becomes visible.
    log_format = '%(asctime)s %(levelname)-8s: %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)

    main(options.start_port, options.instances_num)
