# -*- coding: utf-8 -*-

from sandbox import common
import sandbox.common.types.misc as ctm
import datetime
import logging
import os
import sandbox.common.types.resource as ctr
from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sdk2 import parameters
from sandbox.sdk2.helpers import subprocess
from sandbox.projects.answers import resources
from sandbox.projects.answers.common.encrypt_mixin import EncryptMixin, GpgSettings
from sandbox.projects.answers.common.psql import PostgreSQLStuff, Config


# Maps a PostgreSQL column type, as reported by the ``data_type`` column of
# ``information_schema.columns``, to the type name put into the YT table
# schema. Rows are exported with ``TO_JSON(...)``, so temporal, uuid and
# binary columns are declared as ``string`` (presumably because they arrive
# as JSON strings -- confirm against the PostgreSQL JSON docs), while
# composite values (arrays, json, user-defined types) are declared as ``any``.
# A type missing from this table makes the schema lookup fail with KeyError.
PG_TO_YT_TYPE_MAPPING = {
    'boolean': 'boolean',
    'character varying': 'string',
    'character': 'string',
    'double precision': 'double',
    'integer': 'int64',
    'bigint': 'int64',
    'smallint': 'int64',
    'timestamp without time zone': 'string',
    'timestamp with time zone': 'string',
    'date': 'string',
    'text': 'string',
    'uuid': 'string',
    'ARRAY': 'any',
    'json': 'any',
    'jsonb': 'any',
    'USER-DEFINED': 'any',
    'bytea': 'string',
    'tsvector': 'any',
}


class PsqlContextWrapper(object):
    """Context manager that runs a local PostgreSQL instance for a ``with`` body.

    The wrapped object only needs ``start_local()`` and ``stop_local()``
    methods. ``stop_local()`` is called on exit whether or not the body
    raised; exceptions from the body are never suppressed.
    """

    def __init__(self, psql):
        self.psql = psql

    def __enter__(self):
        self.psql.start_local()
        # Return the wrapped object so ``with PsqlContextWrapper(p) as psql``
        # binds something useful; callers that omit ``as`` are unaffected.
        return self.psql

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.psql.stop_local()
        # Explicitly falsy: let any exception from the body propagate.
        return False

def get_pg_connection(psql_settings):
    """Open a psycopg2 connection to the PostgreSQL server on localhost.

    :param psql_settings: object exposing ``dbname``, ``username``,
        ``password`` and ``port`` attributes.
    :return: a new psycopg2 connection; the caller is responsible for
        closing it.
    """
    # Imported lazily: psycopg2 is provided by the task's PipEnvironment and
    # is not importable when the module is merely loaded.
    import psycopg2

    # Pass the settings as keyword arguments instead of formatting them into
    # a DSN string: psycopg2 escapes the values itself, so a password or
    # database name containing quotes, backslashes or spaces cannot corrupt
    # the connection string.
    return psycopg2.connect(
        dbname=psql_settings.dbname,
        user=psql_settings.username,
        host='localhost',
        password=psql_settings.password,
        port=psql_settings.port,
    )


class UploadJob(object):
    """Copy one PostgreSQL table into one YT table.

    :param pg_table: name of the source table in the local database.
    :param pg_config: settings object accepted by ``get_pg_connection``.
    :param yt_table: full YT path of the destination table.
    :param yt_token: YT auth token.
    :param yt_proxy: YT cluster proxy name.
    """

    def __init__(self, pg_table, pg_config, yt_table, yt_token, yt_proxy):
        self.pg_table = pg_table
        self.pg_config = pg_config
        self.yt_table = yt_table
        self.yt_token = yt_token
        self.yt_proxy = yt_proxy

    def upload(self):
        """Create the YT table with a derived schema and stream all rows into it.

        Each row is serialized to a JSON document by PostgreSQL itself
        (``TO_JSON(t.*)::text``) and handed to YT as raw JSON, so no per-row
        conversion happens on the Python side.
        """
        from psycopg2 import sql
        from yt.wrapper import YtClient, JsonFormat

        yt_client = YtClient(proxy=self.yt_proxy, token=self.yt_token)
        pg_client = get_pg_connection(self.pg_config)
        # A dedicated connection is opened per job; close it when done so
        # that uploading many tables does not pile up open connections.
        try:
            cursor = pg_client.cursor()
            yt_schema = self.get_yt_schema(cursor, self.pg_table)
            yt_client.create(
                "table",
                self.yt_table,
                recursive=True,
                attributes={'schema': yt_schema},
            )

            # sql.Identifier quotes the table name, so it is never spliced
            # into the statement as raw text.
            cursor.execute(
                sql.SQL(
                    'SELECT TO_JSON(t.*)::text FROM {} t'
                ).format(
                    sql.Identifier(self.pg_table)
                )
            )
            yt_client.write_table(
                self.yt_table,
                # Every result row is a 1-tuple holding the JSON text.
                (row for row, in cursor),
                raw=True,
                format=JsonFormat(attributes={
                    'encode_utf8': False,
                    'enable_integral_to_double_conversion': True,
                    'plain': True,
                })
            )
        finally:
            pg_client.close()

    def get_yt_schema(self, cur, table):
        """Build the YT schema for ``table`` from ``information_schema.columns``.

        :param cur: open database cursor used to run the query.
        :param table: table name to describe.
        :return: list of ``{"name": ..., "type": ...}`` dicts in the table's
            column order.
        :raises KeyError: if a column has a PostgreSQL type that has no entry
            in ``PG_TO_YT_TYPE_MAPPING``.
        """
        # Restrict to the ``public`` schema (the only one the task restores
        # and lists tables from): information_schema.columns also describes
        # system relations, so a table that shares its name with one of them
        # would otherwise get foreign columns mixed into its schema.
        # ORDER BY makes the column order deterministic.
        cur.execute(
            """
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_name = %s
            AND table_schema = 'public'
            ORDER BY ordinal_position
            """,
            [table]
        )
        yt_schema = []
        for column_name, data_type in cur:
            try:
                yt_type = PG_TO_YT_TYPE_MAPPING[data_type]
            except KeyError:
                # Same exception type as before, but say what is missing.
                raise KeyError(
                    'No YT type mapping for PostgreSQL type {!r} '
                    '(column {!r} of table {!r})'.format(
                        data_type, column_name, table
                    )
                )
            yt_schema.append({"name": column_name, "type": yt_type})
        return yt_schema


class AnswersDumpToYt(sdk2.Task, EncryptMixin):
    """Restore an encrypted Answers PostgreSQL dump locally and copy it to YT.

    The task decrypts a dump resource, restores it into a locally started
    PostgreSQL instance, uploads every base table of the ``public`` schema
    into a timestamped YT directory, repoints the ``latest`` link to it and
    optionally removes old snapshot directories.
    """

    # Names of tables that must not be uploaded; only membership is tested.
    skip_tables = {}

    class Requirements(sdk2.Task.Requirements):
        environments = (
            PipEnvironment('yandex-yt'),
            PipEnvironment('yandex-yt-yson-bindings-skynet'),
            PipEnvironment('retry'),
            PipEnvironment('psycopg2-binary'),
        )

    class Parameters(sdk2.Parameters):

        use_last_dump = parameters.Bool(
            'Use last database dump to copy on YT',
            required=True,
            default=True,
        )
        # An explicit dump resource is only requested when the latest dump
        # is not used.
        with use_last_dump.value[False]:
            psql_dump = parameters.Resource(
                'Dump of Answers PSQL',
                resource_type=resources.AnswersPostgresqlDump,
                required=True,
            )
        postgres_resource = parameters.Resource(
            'Resource with PSQL',
            resource_type=resources.AnswersPostgresql,
            required=True,
        )
        ramdrive_size = parameters.Integer('RamDrive size in GB', default=4)
        yt_base_directory = parameters.String(
            'Yt base directory',
            required=True,
        )
        yt_proxy = parameters.String(
            'Yt proxy',
            required=True,
            default='banach',
        )
        yt_token = parameters.String(
            'Answers Yt Token Secret',
            required=True,
        )
        gpg_key_owner = parameters.String(
            'Gpg Key Owner',
            required=True,
        )
        env = parameters.String(
            'Database environment',
            choices=[
                ('dev', resources.Environments.DEV),
                ('prod', resources.Environments.PROD),
                ('prestable', resources.Environments.PRESTABLE),
            ],
            required=True,
        )
        left_only_n_snapshots = parameters.Integer(
            'Left only N snapshots, remove rest (0 = unlimited)',
            default=0
        )

    def on_enqueue(self):
        """Request a tmpfs RamDrive when a non-zero size is configured."""
        sdk2.Task.on_enqueue(self)
        if self.Parameters.ramdrive_size:
            self.Requirements.ramdrive = ctm.RamDrive(
                ctm.RamDriveType.TMPFS,
                # The parameter is in GB; ``<< 10`` multiplies it by 1024
                # (presumably RamDrive takes MiB -- confirm against the SDK).
                int(self.Parameters.ramdrive_size) << 10,
                None
            )

    def setup_ramdrive(self):
        """Make the RamDrive, if one was allocated, the working directory."""
        if self.ramdrive:
            logging.info(
                'Setup RamDrive size: %s path: %s',
                common.utils.size2str(self.ramdrive.size << 20),
                self.ramdrive.path,
            )
            os.chdir(str(self.ramdrive.path))

    def config_psql(self, workdir, psql_path):
        """Create the local PostgreSQL helper and its config.

        :param workdir: working directory passed to ``Config``.
        :param psql_path: path to the unpacked PostgreSQL resource.
        :return: ``(psql, config)`` tuple.
        """
        config = Config(work_dir=workdir)
        psql = PostgreSQLStuff(config, psql_path)
        return psql, config

    def fetch_tables(self, connection):
        """Return the names of all base tables in the ``public`` schema.

        :param connection: open database connection.
        :return: list of table names, in the order the server returned them.
        """
        cur = connection.cursor()
        try:
            cur.execute(
                '''
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_type='BASE TABLE'
                    AND table_schema='public'
                '''
            )
            return [row[0] for row in cur.fetchall()]
        finally:
            # Do not leave the cursor open after the names are read.
            cur.close()

    def restore_pg(self, dump_path, psql_config, psql_path):
        """Restore the decrypted dump into the local database with pg_restore.

        :param dump_path: path to the decrypted dump.
        :param psql_config: config with ``dbname``, ``port``, ``username``
            and ``password``.
        :param psql_path: root of the PostgreSQL distribution (``bin/pg_restore``
            is resolved relative to it).
        :raises Exception: if pg_restore exits with a non-zero code.
        """
        with sdk2.helpers.ProcessLog(
                self, logging.getLogger('psql_restore')
        ) as pl:
            restore = subprocess.Popen(
                [
                    os.path.join(psql_path, 'bin/pg_restore'),
                    '--dbname={}'.format(psql_config.dbname),
                    '--host={}'.format('localhost'),
                    '--port={}'.format(psql_config.port),
                    '--username={}'.format(psql_config.username),
                    '--password',
                    '--schema=public',
                    '--no-owner',
                    '--no-privileges',
                    '--section=pre-data',
                    '--section=data',
                    '--jobs=4',
                    '--exit-on-error',
                    dump_path,

                ],
                stdin=subprocess.PIPE,
                stdout=pl.stdout,
                stderr=pl.stderr,
            )
            # ``--password`` forces a prompt; the password is fed via stdin.
            # communicate() also waits for the process to finish, so the
            # exit code is available right after it returns.
            restore.communicate(psql_config.password)
            exitcode = restore.returncode
            if exitcode:
                raise Exception(
                    'Failed to restore database from dump exitcode: {}'.format(
                        exitcode
                    )
                )

    def remove_old_tables(self, yt_client):
        """Keep only the newest N snapshot directories under the base directory.

        Does nothing when ``left_only_n_snapshots`` is not positive. The
        ``latest`` entry is never removed. Directory names are sorted as
        strings; snapshot names are ISO timestamps, so that is also the
        chronological order.
        """
        if self.Parameters.left_only_n_snapshots <= 0:
            return
        snapshot_dir = self.Parameters.yt_base_directory
        folders = [folder for folder in yt_client.list(snapshot_dir) if folder != 'latest']
        folders.sort()

        for folder in folders[:-self.Parameters.left_only_n_snapshots]:
            yt_client.remove(
                os.path.join(snapshot_dir, folder),
                recursive=True,
                force=True
            )

    def get_encrypted_dump_path(self):
        """Return the local path of the encrypted dump resource to use.

        Uses the explicitly selected ``psql_dump`` resource, or, when
        ``use_last_dump`` is set, the READY ``AnswersPostgresqlDump`` with the
        highest id for the configured environment.

        :raises Exception: if ``use_last_dump`` is set and no matching dump
            resource exists.
        """
        if not self.Parameters.use_last_dump:
            result = str(
                sdk2.ResourceData(
                    sdk2.Resource[self.Parameters.psql_dump]
                ).path
            )
        else:
            dump_resource = sdk2.Resource.find(
                resources.AnswersPostgresqlDump,
                attrs={'env': self.Parameters.env},
                state=ctr.State.READY,
            ).order(
                -resources.AnswersPostgresqlDump.id,
            ).first()
            # Fail with a clear message instead of passing None further down.
            if dump_resource is None:
                raise Exception(
                    'No READY dump resource found for env: {}'.format(
                        self.Parameters.env
                    )
                )
            result = str(sdk2.ResourceData(dump_resource).path)
        logging.info('Using dump resource: %s', result)
        return result

    def on_execute(self):
        """Run the whole pipeline: decrypt, restore, upload, relink, clean up."""
        from yt.wrapper import YtClient
        settings = GpgSettings(
            key_owner='YASAP',
            secret_key_name='answers_pgp_private_key',
            public_key_name='answers_pgp_public_key',
            recipient=self.Parameters.gpg_key_owner,
        )

        self.setup_ramdrive()
        local_psql_path = str(
            sdk2.ResourceData(
                sdk2.Resource[self.Parameters.postgres_resource]
            ).path
        )
        encrypted_dump_path = self.get_encrypted_dump_path()
        local_dump_path = self.decrypt(encrypted_dump_path, settings)

        # Stored on ``self`` so that on_break can kill the server.
        self.psql, psql_config = self.config_psql('psql_data', local_psql_path)

        yt_token = sdk2.Vault.data(self.Parameters.yt_token)
        yt = YtClient(proxy=self.Parameters.yt_proxy, token=yt_token)

        if not yt.exists(self.Parameters.yt_base_directory):
            yt.create('map_node', self.Parameters.yt_base_directory)

        with PsqlContextWrapper(self.psql):
            connection = get_pg_connection(psql_config)
            # This connection is only needed to list the tables; each
            # UploadJob opens its own. Close it as soon as the list is
            # read instead of holding it through all uploads.
            try:
                self.restore_pg(local_dump_path, psql_config, local_psql_path)

                yt_dir_path = os.path.join(
                    self.Parameters.yt_base_directory, datetime.datetime.utcnow().isoformat()
                )
                yt.create('map_node', yt_dir_path)

                tables = self.fetch_tables(connection)
            finally:
                connection.close()
            tables.sort()

            for idx, table in enumerate(tables):
                if table in self.skip_tables:
                    logging.warning('Skip table: %s', table)
                    continue

                job = UploadJob(
                    table,
                    psql_config,
                    os.path.join(yt_dir_path, table),
                    yt_token,
                    self.Parameters.yt_proxy
                )
                logging.info('Process table %s (%d of %d)', table, idx + 1, len(tables))
                job.upload()

            # Repoint ``latest`` only after every table has been uploaded.
            latest_link_path = os.path.join(self.Parameters.yt_base_directory, 'latest')
            yt.link(yt_dir_path, latest_link_path, force=True)

            self.remove_old_tables(yt)

    def on_break(self, prev_status, status):
        """Kill helper processes when the task is interrupted."""
        if hasattr(self, 'psql'):
            self.psql._kill()

        # NOTE(review): nothing in this class assigns ``self.pool``; kept for
        # subclasses or mixins that may set it -- confirm before removing.
        if hasattr(self, 'pool'):
            self.pool.terminate()
