import yatest.common
import yatest.common.runtime
import os
import time
import shutil
import socket
import signal
import pymysql
import stat
import uuid
import contextlib

from direct.perl.perltests.utils.utils import log

# Template for the db-tools.yaml file written by Mysql.write_config().
# It is rendered with str.format(host, port): {0} becomes the MySQL host and
# {1} the port. The YAML's literal empty mappings are therefore written as
# doubled braces ("{{}}") so that format() emits a single "{}".
# The CHILDS tree lists the logical database names the config exposes;
# presumably these are looked up by the code reading UNIT_TEST_DB_CONFIG_PATH
# - TODO confirm against that consumer.
CONFIG_TEMPLATE = """---
db_config:
  utf8: 1
  connect_timeout: 60
  user: root
  pass: ''
  AutoCommit: 1
  host: {0}
  port: {1}
  CHILDS:
    fakebalance: {{}}
    fakeblackbox: {{}}
    monitor: {{}}
    ppc:
        CHILDS:
            1: {{}}
            2: {{}}
    ppcdict: {{}}
    ppchouse: {{}}
    ppclog: {{}}
    ppcpricelog: {{}}
    rbac: {{}}
    unit_tests: {{}}
    sharded_unit_tests:
      CHILDS:
        1: {{}}
        2: {{}}
"""

# Address mysqld binds to (--bind-address) and that clients connect to.
# It lies in 127.0.0.0/8, i.e. loopback; presumably a non-default loopback
# address was chosen to avoid colliding with other local services - confirm.
mysql_listen_addr = "127.210.213.90"


def find_free_port():
    """Return a TCP port number the OS currently reports as unused.

    A temporary IPv4 stream socket is bound to port 0 on all interfaces, so
    the kernel picks a free port. The socket is closed before returning, so
    another process may still grab the port before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', 0))
        _host, port = probe.getsockname()
    finally:
        probe.close()
    return port


class Mysql():
    """A throwaway mysqld instance for the unit tests.

    ``start_mysql`` copies the prepared data directory to a writable location,
    writes db-tools.yaml, and launches the bundled mysqld bound to
    ``mysql_listen_addr`` on a free port. It then waits until that port accepts
    TCP connections. ``stop_mysql`` sends SIGTERM to the server process.
    """

    # Both are assigned by start_mysql(); None until the server is launched.
    port = None
    pid = None

    def write_config(self):
        """Write db-tools.yaml pointing at this server and export its path.

        The rendered config is also printed to stdout. Its path is published
        through the UNIT_TEST_DB_CONFIG_PATH environment variable.
        """
        log('Creating db-tools.yaml...')
        config_path = yatest.common.runtime.work_path("db-tools.yaml")
        with open(config_path, "w") as f:
            config = CONFIG_TEMPLATE.format(mysql_listen_addr, self.port)
            print(config)
            f.write(config)

        os.environ['UNIT_TEST_DB_CONFIG_PATH'] = config_path

    def wait_for_connection(self, timeout):
        """Block until the server accepts TCP connections on ``self.port``.

        Retries about once per second. Each probe socket is closed, whether or
        not the connect succeeds.

        Raises:
            socket.error: the last connect error, once ``timeout`` seconds have
                elapsed without a successful connection.
        """
        log("wait for connection")
        deadline = time.time() + timeout
        while True:
            # closing() guarantees every probe socket is released; the original
            # loop leaked one descriptor per attempt.
            with contextlib.closing(socket.socket()) as sock:
                try:
                    sock.connect((mysql_listen_addr, self.port))
                except socket.error as exc:
                    log(exc)
                    if time.time() >= deadline:
                        raise
                else:
                    log("connection acquired")
                    return
            time.sleep(1)

    def my_connect(self):
        """Open a pymysql connection to this server as ``root`` (no password given)."""
        return pymysql.connect(
            host=mysql_listen_addr,
            port=self.port,
            user='root',
            program_name='test',
        )

    def prepare_data_dir(self):
        """Copy the read-only test data dir to a writable one and return its path."""
        data_dir_ro = yatest.common.runtime.work_path("mysql-test-data")
        data_dir = yatest.common.runtime.work_path("mysql-test-data-rw")
        shutil.copytree(data_dir_ro, data_dir)

        # copytree preserves the source permission bits, so add the owner-write
        # bit to directories as well as files. MySQL keeps each database as a
        # subdirectory of the datadir and must be able to create and remove
        # entries there.
        for root, dirs, files in os.walk(data_dir):
            for path in [root] + [os.path.join(root, name) for name in files]:
                os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE)

        return data_dir

    def start_mysql(self):
        """Launch mysqld and block until it accepts connections (up to 20 s).

        Sets ``self.port`` and ``self.pid``. If the server does not come up in
        time, the still-running process is terminated before the error
        propagates, so no orphaned mysqld is left behind.
        """
        server_dir = yatest.common.runtime.work_path("mysql-server")
        data_dir = self.prepare_data_dir()

        self.port = find_free_port()

        # mysqld is started through the bundled dynamic loader so that it uses
        # the libraries shipped in server_dir rather than the host's.
        cmd = [
            server_dir + "/ld-linux-x86-64.so.2",
            "--inhibit-cache",
            "--library-path", server_dir,
            server_dir + "/sbin/mysqld",
            "--no-defaults",
            "--basedir", server_dir,
            "--lc-messages-dir", server_dir,
            "--innodb_use_native_aio=0",

            # "--log-error=" + server_dir + "/mysql.err",

            "--character-set-server=utf8",
            "--skip-name-resolve",
            "--innodb-flush-method=nosync",
            "--innodb-buffer-pool-dump-at-shutdown=OFF",
            "--innodb-buffer-pool-load-at-startup=OFF",
            "--innodb_log_file_size=1M",
            "--innodb-file-per-table=OFF",
            "--innodb-stats-persistent=OFF",
            "--innodb-doublewrite=0",
            "--innodb-flush-log-at-trx-commit=0",
            "--innodb-monitor-disable=all",

            "--performance-schema=OFF",
            "--sync-frm=OFF",
            "--skip-ssl",
            "--default-time-zone=Europe/Moscow",
            "--datadir="+data_dir,
            "--socket=",
            "--bind-address="+mysql_listen_addr,
            "--port="+str(self.port),
            "--secure-file-priv="
        ]

        self.write_config()

        log("Running {0}".format(cmd))

        mp = yatest.common.execute(cmd, wait=False)
        # Record the pid before waiting so the process can be stopped even if
        # startup does not complete.
        self.pid = mp.process.pid

        started = False
        try:
            self.wait_for_connection(20)
            started = True
        finally:
            # Startup failed (timeout, interrupt, ...): kill the server unless
            # it has already exited on its own.
            if not started and mp.process.poll() is None:
                self.stop_mysql()

        return None

    def stop_mysql(self):
        """Send SIGTERM to the mysqld process; does not wait for it to exit."""
        os.kill(self.pid, signal.SIGTERM)


def yield_create_databases(mysql_connection):
    """Create a uniquely named test database plus four shards, yield its name.

    Creates ``unit_tests_<uuid>`` and ``unit_tests_<uuid>_1`` ..
    ``unit_tests_<uuid>_4`` (utf8 charset) and yields the base name. Afterwards
    it drops every database it created. Presumably this is used as a
    generator-style test fixture - confirm with callers.

    Cleanup runs in ``finally``, so databases are dropped even if a later
    CREATE fails or the generator is closed instead of resumed. Only databases
    that were actually created are dropped.
    """
    # Hyphens from the UUID are replaced so the name is a valid unquoted identifier.
    db_name = "unit_tests_" + str(uuid.uuid4()).replace('-', '_')
    wanted = [db_name] + [db_name + '_' + str(i) for i in range(1, 5)]

    def _run(sql):
        # Fresh cursor per statement (as before), now always closed.
        with contextlib.closing(mysql_connection.cursor()) as cursor:
            cursor.execute(sql)

    created = []
    try:
        for name in wanted:
            _run('create database ' + name + ' default charset utf8')
            created.append(name)

        yield db_name
    finally:
        for name in created:
            _run('drop database ' + name)


def yield_mysql_connection():
    """Start a throwaway mysqld, yield a root connection, then shut down.

    The global ``sql_mode`` is reset to '' before the connection is yielded.
    Teardown is in ``finally`` blocks: the connection is closed and the server
    is stopped even if connecting or configuring fails, or if the generator is
    closed instead of resumed.
    """
    mysql_obj = Mysql()
    mysql_obj.start_mysql()
    try:
        connection = mysql_obj.my_connect()
        try:
            with contextlib.closing(connection.cursor()) as cursor:
                cursor.execute("set global sql_mode=''")

            yield connection
        finally:
            connection.close()
    finally:
        mysql_obj.stop_mysql()
