# -*- coding: utf-8 -*-
import datetime
import logging
import os
import requests
import shutil
import six
import sys
import time

import sandbox.common.types.misc as ctm
from sandbox import common
from sandbox import sdk2
from sandbox.projects.answers.common.psql import PostgreSQLStuff, Config
from sandbox.projects.hqcg import HQCG_PG_RESOURCE
from sandbox.sdk2 import parameters
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sdk2.helpers import subprocess


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def identity(x):
    """Return ``x`` unchanged.

    Serves as the no-op cast for column types whose values are uploaded as-is.
    """
    return x


# Maps a PostgreSQL ``data_type`` string (as read from information_schema.columns)
# to a pair ``(YT column type, cast callable)``:
#   * element [0] becomes the "type" of the column in the YT table schema;
#   * element [1] is applied to every non-None value of that column before upload.
# A data type missing from this table makes the schema/cast-map builders raise KeyError.
PG_TO_YT_TYPE_MAPPING = {
    'boolean': ('boolean', bool),
    'character varying': ('string', unicode),
    'integer': ('int64', int),
    'bigint': ('int64', int),
    # Timestamps are stored in YT as plain strings.
    'timestamp with time zone': ('string', str),
    'timestamp without time zone': ('string', str),
    'text': ('string', unicode),
    # Composite values are stored as YT "any" and passed through untouched.
    'ARRAY': ('any', identity),
    'json': ('any', identity),
    'USER-DEFINED': ('any', identity),
}

# Base URL of the Reactor HTTP API and the namespace path of the artifact that is
# instantiated once a backup has been uploaded.
REACTOR_API_BASE = 'https://reactor.yandex-team.ru'
REACTOR_ARTIFACT_NAMESPACE_PATH = '/search/talk/db_backup/backup_ready'


class PsqlContextWrapper(object):
    """Context manager that keeps a local PostgreSQL instance running for a ``with`` body.

    ``psql`` must provide ``start_local()`` and ``stop_local()``. The instance is
    started on entry and stopped on exit, whether or not the body raised.
    Exceptions raised inside the body are never suppressed.
    """

    def __init__(self, psql):
        self.psql = psql

    def __enter__(self):
        self.psql.start_local()
        # Hand the wrapped object back so ``with PsqlContextWrapper(p) as psql`` is usable.
        return self.psql

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.psql.stop_local()
        # Explicitly do not swallow exceptions from the ``with`` body.
        return False


class HqcgDumpDbToYt(sdk2.Task):
    class Requirements(sdk2.Task.Requirements):
        # Pip packages requested for the task environment. ``yt.wrapper`` and
        # ``psycopg2`` are imported lazily inside the task methods, presumably
        # because they only exist once these environments are prepared.
        environments = (
            PipEnvironment('yandex-yt'),
            PipEnvironment('yandex-yt-yson-bindings-skynet'),
            # NOTE(review): 'retry' is not imported anywhere in the visible code.
            PipEnvironment('retry'),
            PipEnvironment('psycopg2-binary'),
        )

    class Parameters(sdk2.Parameters):
        # Connection settings of the source database; passed to pg_dump.
        dbname = parameters.String('Database Name', required=True)
        dbhost = parameters.String('Database Host', required=True)
        dbport = parameters.Integer('Database Port', required=True)
        dbuser = parameters.String('Database User', required=True)
        # The value is passed to sdk2.Vault.data() to obtain the actual password.
        dbpasswd = parameters.String('Database User Password', required=True)
        # The resource path is expected to contain bin/pg_dump and bin/pg_restore.
        pg_resource = parameters.Resource('Resource with PSQL', resource_type=HQCG_PG_RESOURCE, required=True)
        # A falsy value (e.g. 0) disables the RAM drive request made in on_enqueue.
        ramdrive_size = parameters.Integer('RamDrive size in GB', default=4)
        yt_proxy = parameters.String('Yt proxy', required=True, default='hahn')
        # The value is passed to sdk2.Vault.data() to obtain the actual token.
        yt_token = parameters.String('Yt Token Secret', required=True)
        # Snapshot directories and the 'latest' link are created under this YT path.
        yt_base_directory = parameters.String('Yt base directory', required=True)
        # Optional: when empty, no Reactor artifact is instantiated. Otherwise the
        # value is passed to sdk2.Vault.data() to obtain the OAuth token.
        reactor_api_token = parameters.String('Reactor API Token Secret')

    def on_enqueue(self):
        """Run the base enqueue hook and request a tmpfs RAM drive if a size is configured."""
        sdk2.Task.on_enqueue(self)
        if not self.Parameters.ramdrive_size:
            return
        # The parameter is given in GB; shifting by 10 bits multiplies it by 1024
        # (presumably the RamDrive size is expected in MiB -- TODO confirm).
        requested_size = int(self.Parameters.ramdrive_size) << 10
        self.Requirements.ramdrive = ctm.RamDrive(ctm.RamDriveType.TMPFS, requested_size, None)

    def _setup_ramdrive(self):
        """Switch the working directory to the RAM drive when the task has one.

        Does nothing when ``self.ramdrive`` is falsy.
        """
        if not self.ramdrive:
            return
        # Shift by 20 bits before formatting (presumably MiB -> bytes; TODO confirm).
        readable_size = common.utils.size2str(self.ramdrive.size << 20)
        logging.info('Setup RamDrive size: %s path: %s', readable_size, self.ramdrive.path)
        os.chdir(str(self.ramdrive.path))

    @staticmethod
    def _config_psql(workdir, psql_path):
        """Build the local PostgreSQL helper and its configuration.

        :param workdir: working directory handed to ``Config``.
        :param psql_path: path handed to ``PostgreSQLStuff`` together with the config.
        :return: tuple ``(PostgreSQLStuff instance, Config instance)``.
        """
        psql_config = Config(work_dir=workdir)
        return PostgreSQLStuff(psql_config, psql_path), psql_config

    def _dump_database(self, psql_path):
        """Dump the ``public`` schema of the source database with ``pg_dump``.

        The dump is written in directory format to ``hqcg_dump`` inside the current
        working directory; the ``repl_mon`` table is excluded. The password is read
        from Vault and fed to ``pg_dump`` through stdin.

        :param psql_path: directory that contains ``bin/pg_dump``.
        :return: absolute path of the dump directory.
        :raises Exception: if ``pg_dump`` exits with a non-zero code or the dump
            directory was not created.
        """
        pg_dump_path = os.path.join(psql_path, 'bin/pg_dump')
        dump_path = os.path.join(os.getcwd(), 'hqcg_dump')
        with sdk2.helpers.ProcessLog(self, logging.getLogger('psql_dump')) as pl:
            psql_password = sdk2.Vault.data(self.Parameters.dbpasswd)
            dump = subprocess.Popen(
                [
                    pg_dump_path,
                    '--dbname={}'.format(self.Parameters.dbname),
                    '--host={}'.format(self.Parameters.dbhost),
                    '--port={}'.format(self.Parameters.dbport),
                    '--username={}'.format(self.Parameters.dbuser),
                    '--password',
                    '--format=d',
                    '--schema=public',
                    '--exclude-table={}'.format('repl_mon'),
                    '--file={}'.format(dump_path)
                ],
                stdin=subprocess.PIPE,
                stdout=pl.stdout,
                stderr=pl.stderr,
            )
            dump.communicate(psql_password)
            exitcode = dump.wait()
        # A failed pg_dump may still leave a partial directory behind, so the exit
        # code is checked in addition to the existence of the output path.
        if exitcode or not os.path.exists(dump_path):
            raise Exception('Dump failed with exitcode {}'.format(exitcode))
        return dump_path

    def _restore_pg(self, dump_path, psql_config, psql_path):
        """Load a directory-format dump into the local PostgreSQL instance.

        Runs ``pg_restore`` against ``localhost`` in a single transaction, without
        owners and privileges, feeding the password from ``psql_config`` via stdin.

        :param dump_path: path of the dump to restore.
        :param psql_config: object providing ``dbname``, ``port``, ``username`` and ``password``.
        :param psql_path: directory that contains ``bin/pg_restore``.
        :raises Exception: if ``pg_restore`` exits with a non-zero code.
        """
        restore_logger = logging.getLogger('psql_restore')
        with sdk2.helpers.ProcessLog(self, restore_logger) as pl:
            command = [
                os.path.join(psql_path, 'bin/pg_restore'),
                '--dbname={}'.format(psql_config.dbname),
                '--host=localhost',
                '--port={}'.format(psql_config.port),
                '--username={}'.format(psql_config.username),
                '--password',
                '--format=d',
                '--schema=public',
                '--single-transaction',
                '--no-owner',
                '--no-privileges',
                dump_path,
            ]
            process = subprocess.Popen(
                command,
                stdin=subprocess.PIPE,
                stdout=pl.stdout,
                stderr=pl.stderr,
            )
            process.communicate(psql_config.password)
            status = process.wait()
            if status:
                raise Exception('Failed to restore database from dump exitcode: {}'.format(status))

    @staticmethod
    def _get_psql_connection(psql_settings):
        """Open a psycopg2 connection to the local PostgreSQL instance.

        The connection parameters are passed as keyword arguments rather than
        interpolated into a DSN string, so values containing quotes or backslashes
        (a password, for instance) cannot break or alter the connection string.

        :param psql_settings: object providing ``dbname``, ``username``,
            ``password`` and ``port``.
        :return: an open psycopg2 connection to ``localhost``.
        """
        # Imported lazily, like in the rest of this task.
        import psycopg2
        return psycopg2.connect(
            dbname=psql_settings.dbname,
            user=psql_settings.username,
            host='localhost',
            password=psql_settings.password,
            port=psql_settings.port,
        )

    @staticmethod
    def _fetch_tables(connection):
        cur = connection.cursor()
        cur.execute(
            "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND table_schema='public'")
        return [row[0] for row in cur.fetchall()]

    @staticmethod
    def _cursor_iterator(cursor):
        row = cursor.fetchone()
        while row:
            logging.warning('Row: %s', row)
            yield row
            row = cursor.fetchone()

    @staticmethod
    def _get_pg_schema(cur, table):
        from psycopg2 import sql
        cur.execute(sql.SQL('SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s'),
                    [table])
        # PG-schema
        return dict((row[0], row[1]) for row in HqcgDumpDbToYt._cursor_iterator(cur))

    @staticmethod
    def _get_yt_schema(pg_schema):
        """Translate a PostgreSQL column-type mapping into a YT table schema.

        :param pg_schema: dict mapping column name to PostgreSQL ``data_type``.
        :return: list of ``{"name": ..., "type": ...}`` dicts, one per column.
        :raises KeyError: for a data type absent from ``PG_TO_YT_TYPE_MAPPING``.
        """
        yt_columns = []
        for column, pg_type in pg_schema.items():
            yt_type = PG_TO_YT_TYPE_MAPPING[pg_type][0]
            yt_columns.append({"name": column, "type": yt_type})
        return yt_columns

    @staticmethod
    def _get_cast_map(pg_schema):
        """Map every column name to the callable that converts its values for YT.

        :param pg_schema: dict mapping column name to PostgreSQL ``data_type``.
        :return: dict mapping column name to a cast callable.
        :raises KeyError: for a data type absent from ``PG_TO_YT_TYPE_MAPPING``.
        """
        casts = {}
        for column, pg_type in pg_schema.items():
            casts[column] = PG_TO_YT_TYPE_MAPPING[pg_type][1]
        return casts

    @staticmethod
    def _type_caster(obj, cast_map):
        return dict((key, cast_map[key](value) if value is not None else None) for (key, value) in obj.items())

    def _dump_to_yt(self, tables, yt_client, connection):
        dir_path = os.path.join(self.Parameters.yt_base_directory, datetime.datetime.utcnow().isoformat())
        yt_client.create('map_node', dir_path)
        for table in tables:
            self._upload_to_yt(yt_client, connection, dir_path, table)
        latest_link_path = os.path.join(self.Parameters.yt_base_directory, 'latest')
        yt_client.link(dir_path, latest_link_path, force=True)

    def _upload_to_yt(self, yt_client, psql_conn, dst_dir, table):
        """Copy one PostgreSQL table into a YT table inside a YT transaction.

        The YT schema is derived from the PostgreSQL column types. Every row is
        selected through ``ROW_TO_JSON`` and its values are cast column by column
        before being written.

        :param yt_client: YT client used to create and write the table.
        :param psql_conn: database connection the table is read from.
        :param dst_dir: YT directory the table is created in.
        :param table: table name; used both as SQL identifier and as YT node name.
        """
        from psycopg2 import sql
        table_path = os.path.join(dst_dir, table)
        with yt_client.Transaction():
            cur = psql_conn.cursor()
            try:
                # PG-schema
                table_types = self._get_pg_schema(cur, table)
                # YT-schema
                schema = self._get_yt_schema(table_types)
                # map Field-to-Cast function, e.g. rank: float()
                cast_map = self._get_cast_map(table_types)
                # sql.Identifier quotes the table name safely.
                cur.execute(
                    sql.SQL('SELECT ROW_TO_JSON(the_table) FROM (SELECT * FROM {}) the_table').format(
                        sql.Identifier(table)))
                # Rows are cast lazily while write_table consumes the generator.
                data_gen = (self._type_caster(row[0], cast_map) for row in HqcgDumpDbToYt._cursor_iterator(cur))
                yt_client.create("table", table_path, recursive=True, attributes={'schema': schema})
                yt_client.write_table(table_path, data_gen, raw=False)
            finally:
                # Release the cursor (and its fetched result) whether or not the upload failed.
                cur.close()

    def _remove_old_tables(self, yt_client):
        snapshot_dir = self.Parameters.yt_base_directory
        folders = [folder for folder in yt_client.list(snapshot_dir) if folder != 'latest']
        folders.sort()
        if len(folders) > 7:
            for folder in folders[:-7]:
                yt_client.remove(os.path.join(snapshot_dir, folder), recursive=True, force=True)

    def _create_reactor_artifact_instance(self):
        """Notify Reactor that a backup is ready by instantiating the event artifact.

        Does nothing when the ``reactor_api_token`` parameter is empty. The request
        is attempted up to three times with exponential back-off between attempts.

        :raises requests.RequestException: the error of the last attempt, if all
            attempts failed.
        """
        if not self.Parameters.reactor_api_token:
            return
        logging.info('Creating %s artifact instance', REACTOR_ARTIFACT_NAMESPACE_PATH)
        reactor_api_token = sdk2.Vault.data(self.Parameters.reactor_api_token)
        headers = {
            'Authorization': 'OAuth ' + reactor_api_token,
            'Accept': 'application/json',
            'Content-Type': 'application/json',
        }
        url = REACTOR_API_BASE + '/api/v1/a/i/instantiate'
        data = {
            'artifactIdentifier': {
                'namespaceIdentifier': {
                    'namespacePath': REACTOR_ARTIFACT_NAMESPACE_PATH,
                }
            },
            'metadata': {
                '@type': '/yandex.reactor.artifact.EventArtifactValueProto',
            },
            'userTimestamp': datetime.datetime.utcnow().isoformat()
        }
        attempts = 3
        # Seconds to wait for the server; without a timeout a hung connection
        # would block the task indefinitely.
        request_timeout = 60
        last_exc_info = None
        for attempt in range(attempts):
            try:
                with requests.post(url, json=data, headers=headers, timeout=request_timeout) as r:
                    r.raise_for_status()
                return
            except requests.RequestException as e:
                # Keep the exception explicitly: sys.exc_info() is only reliable
                # inside the handler (Python 3 clears it afterwards).
                last_exc_info = sys.exc_info()
                logging.error('Got RequestError: %s; Attempt: %s', e, attempt)
                # No point in sleeping after the final attempt.
                if attempt + 1 < attempts:
                    time.sleep(2 ** attempt)
        six.reraise(*last_exc_info)

    def on_execute(self):
        """Dump the source database, restore it locally and upload every table to YT.

        Steps:
          1. switch to the RAM drive (if any) and dump the source database;
          2. start a local PostgreSQL instance and restore the dump into it;
          3. upload all ``public`` tables to a new YT snapshot directory;
          4. remove old snapshots and notify Reactor.
        """
        self._setup_ramdrive()

        local_psql_path = str(sdk2.ResourceData(sdk2.Resource[self.Parameters.pg_resource]).path)

        # The dump path is absolute and the working directory does not change after
        # _setup_ramdrive(). The former "copy the dump out of the RAM drive" step
        # compared that path with itself and could never run, so it was removed.
        dump_path = self._dump_database(local_psql_path)

        psql_workdir = 'psql_data'
        self.psql, psql_config = self._config_psql(psql_workdir, local_psql_path)

        # Imported lazily, like the other pip-provided modules used by this task.
        from yt.wrapper import YtClient

        yt_token = sdk2.Vault.data(self.Parameters.yt_token)
        yt = YtClient(proxy=self.Parameters.yt_proxy, token=yt_token)
        if not yt.exists(self.Parameters.yt_base_directory):
            yt.create('map_node', self.Parameters.yt_base_directory)

        with PsqlContextWrapper(self.psql):
            connection = self._get_psql_connection(psql_config)
            try:
                self._restore_pg(dump_path, psql_config, local_psql_path)
                tables = self._fetch_tables(connection)
                self._dump_to_yt(tables, yt, connection)
            finally:
                # Close the connection before the local instance is stopped.
                connection.close()
            self._remove_old_tables(yt)
        self._create_reactor_artifact_instance()

    def on_break(self, prev_status, status):
        if hasattr(self, 'psql'):
            self.psql._kill()
