import logging
import subprocess

from sandbox import sdk2
from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk.process import run_process, check_process_return_code
import sandbox.common.types.client as ctc
import sandbox.common.types.misc as ctm

# Target PostgreSQL schema and table for the omission clusters.
PG_SCHEMA_NAME = "mapsearch"
PG_TABLE_NAME = "omission_cluster"

# DDL sent to psql before the data: the schema is rebuilt from scratch on each run.
DROP_SCHEMA_QUERY = "DROP SCHEMA IF EXISTS {} CASCADE;".format(PG_SCHEMA_NAME)
CREATE_SCHEMA_QUERY = "CREATE SCHEMA {};".format(PG_SCHEMA_NAME)
CREATE_TABLE_QUERY = (
    "CREATE TABLE {schema}.{table}"
    "(id_cluster SERIAL PRIMARY KEY, business_service VARCHAR NULL, service_region VARCHAR NULL, shape GEOMETRY);"
    "CREATE INDEX ON {schema}.{table} USING GIST(shape);"
).format(schema=PG_SCHEMA_NAME, table=PG_TABLE_NAME)
# Prefix of the multi-row INSERT; the value tuples are appended by the writer.
INSERT_QUERY_HEADER = "INSERT INTO {}.{} (business_service, service_region, shape) VALUES ".format(
    PG_SCHEMA_NAME, PG_TABLE_NAME)

# psql connection settings, keyed by the task's ``db_environment`` parameter value.
PSQL_BIN = '/usr/bin/psql'
PSQL_HOST = dict(
    stable='sas-cjxkqfmf1m1fy0c5.db.yandex.net',
    testing='c-mdbh6cqv98bv2p8igsqc.rw.db.yandex.net',
    development='pg94.maps.dev.yandex.net',
)
PSQL_PORT = dict(
    stable='6432',
    testing='6432',
    development='5432',
)
PSQL_DB = dict(
    stable='mapsfactory',
    testing='mapsfactory_data',
    development='mapsfactory_data',
)
PSQL_USER = dict(
    stable='mapscontent',
    testing='mapsfactory',
    development='mapsfactory',
)
# Vault secret names holding the password; environments missing here
# (development) fall back to PSQL_DEFAULT_PSW.
PSQL_PSW_VAULT_KEY = dict(
    stable='mapsfactory-stable-mapsearch-psw3',
    testing='mapsfactory-testing-mapsearch-psw',
)
PSQL_DEFAULT_PSW = 'mapsearch'


class UpdateDataOmissionClusters(sdk2.Task):
    """
    Sandbox task that rebuilds the map data omission clusters table.

    Reads clusters from a YT table on the ``hahn`` cluster and replaces the
    contents of ``PG_SCHEMA_NAME.PG_TABLE_NAME`` in the PostgreSQL database
    selected by the ``db_environment`` parameter.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [environments.PipEnvironment('yandex-yt')]
        privileged = True
        client_tags = ctc.Tag.GENERIC | ctc.Tag.LXC
        dns = ctm.DnsType.DNS64

    class Parameters(sdk2.Task.Parameters):
        with sdk2.parameters.RadioGroup("Output DB environment") as db_environment:
            db_environment.values['stable'] = db_environment.Value(value='Stable', default=True)
            db_environment.values['testing'] = db_environment.Value(value='Testing')
            db_environment.values['development'] = db_environment.Value(value='Development')

        _container = sdk2.parameters.Container(
            "Environment container resource",
            platform="linux_ubuntu_16.04_xenial",
            required=True
        )

    def on_execute(self):
        """
        Read the clusters from YT, install psql and rewrite the PostgreSQL table.
        """
        logging.info('Running...')
        yt_token = sdk2.Vault.data(self.owner, 'yt-token')
        yt_table = '//home/data_coverage/services/statclusters'

        rows = self.read_yt_table(yt_token, yt_table)
        logging.info('Read {} yt-rows'.format(len(rows)))

        # postgresql-client (which provides PSQL_BIN) is installed on every run;
        # presumably the container image does not ship it.
        run_process(['apt-get', 'update'], log_prefix='apt-get')
        run_process(['apt-get', 'install', '--no-install-recommends', '--yes', 'postgresql-client'], log_prefix='apt-get')

        self.write_to_postgres(rows)
        logging.info('Running... done')

    def read_yt_table(self, token, table):
        """
        Read ``table`` from the YT ``hahn`` cluster.

        :param token: YT OAuth token.
        :param table: YT path of the source table.
        :return: list of ``(geometry, group, set_name)`` tuples, one per YT row.
            The values are presumably WKT / plain strings (DSV format); verify
            against the producer of the table.
        """
        # Imported lazily: the yt package comes from the task's PipEnvironment.
        from yt.wrapper import YtClient
        client = YtClient('hahn', token)
        t = client.read_table(table, format='dsv', raw=False)
        return [(item['geometry'], item['group'], item['set_name']) for item in t]

    def write_to_postgres(self, rows):
        """
        Replace the contents of the omission-cluster table with ``rows``.

        The schema is dropped and recreated, and all rows go into one
        multi-row INSERT. psql runs with ``-1`` and ``ON_ERROR_STOP=1`` (see
        ``run_psql``), so an SQL error rolls the whole batch back.

        :param rows: sequence of ``(geometry_wkt, group, set_name)`` tuples, as
            returned by ``read_yt_table``.
        :raises ValueError: if ``rows`` is empty. An empty ``VALUES`` list is not
            valid SQL; the existing table is left untouched and psql is not started.
        """
        if not rows:
            raise ValueError('No rows to write into {}.{}; the table is left unchanged'.format(
                PG_SCHEMA_NAME, PG_TABLE_NAME))

        psql = self.run_psql()
        try:
            self._write_queries(psql.stdin, rows)
            # Closing inside the try: the final buffer flush can still hit a broken pipe.
            psql.stdin.close()
        except BaseException:
            self._abort_psql(psql)
            raise
        logging.debug('Waiting psql...')
        psql.wait()
        logging.debug('Waiting psql... done')
        check_process_return_code(psql)

    @classmethod
    def _write_queries(cls, stream, rows):
        """
        Write the DDL and a single multi-row INSERT for ``rows`` to ``stream``.
        """
        stream.write(DROP_SCHEMA_QUERY)
        stream.write(CREATE_SCHEMA_QUERY)
        stream.write(CREATE_TABLE_QUERY)
        stream.write(INSERT_QUERY_HEADER)
        for index, (geometry, group, set_name) in enumerate(rows):
            if index:
                stream.write(',')
            stream.write("('{}', '{}', ST_Transform(ST_GeomFromText('{}', 4326), 3395))".format(
                cls._escape_literal(set_name),
                cls._escape_literal(group),
                cls._escape_literal(geometry),
            ))
        stream.write(';')

    @staticmethod
    def _escape_literal(value):
        """
        Escape ``value`` for use inside a single-quoted SQL string literal.

        Quotes are doubled (the standard SQL escape) instead of being stripped,
        so values such as "McDonald's" are stored unchanged. This assumes the
        server uses ``standard_conforming_strings = on`` (the PostgreSQL
        default), where backslashes are not escape characters.
        """
        return value.replace("'", "''")

    @staticmethod
    def _abort_psql(psql):
        """
        Stop psql after a failed write without letting it commit partial data.

        With ``-1`` psql wraps its input in BEGIN/COMMIT. It may also execute an
        unterminated trailing statement on EOF. Just closing stdin after a
        partial write could therefore commit a truncated INSERT. Killing the
        process drops the connection, so the server rolls the transaction back.
        Assumes ``run_process(wait=False)`` returns a ``subprocess.Popen``.
        """
        if psql.poll() is None:
            try:
                psql.kill()
            except OSError:
                pass  # psql exited between poll() and kill()
        psql.wait()
        try:
            psql.stdin.close()
        except (IOError, OSError):
            pass  # the pipe is already broken; nothing more can be flushed
        logging.error('psql aborted after a write failure, exit code %s', psql.returncode)

    def run_psql(self):
        """
        Start psql for the selected ``db_environment`` without waiting for it.

        The password comes from Vault when the environment has a key in
        ``PSQL_PSW_VAULT_KEY``; otherwise ``PSQL_DEFAULT_PSW`` is used. It is
        passed through ``PGPASSWORD``.

        :return: the process started by ``run_process``; its stdin is a pipe
            that receives the SQL text.
        """
        env = self.Parameters.db_environment
        host = PSQL_HOST[env]
        port = PSQL_PORT[env]
        vault_key = PSQL_PSW_VAULT_KEY.get(env)
        password = PSQL_DEFAULT_PSW if vault_key is None else sdk2.Vault.data(self.owner, vault_key)

        return run_process(
            [
                PSQL_BIN,
                '--host', host,
                '--port', port,
                '--db', PSQL_DB[env],
                '--user', PSQL_USER[env],
                # Abort on the first error; together with -1 (single transaction)
                # this makes the whole load all-or-nothing.
                '-v', 'ON_ERROR_STOP=1',
                '-1',
                # Ignore any ~/.psqlrc so the session behaves predictably.
                '-X'
            ],
            log_prefix='psql',
            stdin=subprocess.PIPE,
            wait=False,
            outs_to_pipe=False,
            environment={'PGPASSWORD': password}
        )
