import pytest
from unittest.mock import patch
from zenyatta.db.sql import SQL, PostgresSQL, MySQL

from airflow.models import Connection


def airflow_connection_mock():
    """Build a stand-in Airflow ``Connection`` for the SQL driver tests.

    The login, password, host, port and schema set here are the values the
    driver tests expect to find embedded in generated engine strings
    (``test:password@localhost:5432/sitedb``).

    Returns:
        Connection: a new Airflow connection object populated with test data.
    """
    # Airflow models ``port`` as an integer (and parses it to an int from
    # connection URIs), so the mock uses 5432 rather than the string '5432'
    # to mirror the type production code actually receives.
    return Connection(login='test', password='password',
                      host='localhost', port=5432,
                      schema='sitedb')


def test_sql():
    """The abstract ``SQL`` base cannot produce an engine string itself.

    Calling ``engine_string`` directly on a base ``SQL`` instance must raise
    ``NotImplementedError``; the concrete drivers supply the implementation.
    """
    base_driver = SQL('test')
    pytest.raises(NotImplementedError, base_driver.engine_string)


def test_create_driver():
    """``SQL.create_sql_driver`` maps engine names to driver classes.

    Name matching is case-insensitive, and 'Aurora' resolves to the MySQL
    driver. The check is on the exact type, not a subclass.
    """
    expected_driver_types = (
        ('Postgres', PostgresSQL),
        ('postgres', PostgresSQL),
        ('MySQL', MySQL),
        ('mysql', MySQL),
        ('Aurora', MySQL),
    )
    for engine_name, driver_cls in expected_driver_types:
        created = SQL.create_sql_driver(engine_name)
        assert type(created) is driver_cls


@patch('zenyatta.db.sql.get_airflow_connection', return_value=airflow_connection_mock())
def test_postgres_driver(conn_mock):
    """Postgres engine strings are built from the (mocked) Airflow connection.

    With no arguments, host and port come from the mocked connection.
    Explicit ``host``/``port`` keyword arguments replace them while the
    credentials and schema are kept.
    """
    postgres = SQL.create_sql_driver('Postgres')

    # Defaults: everything comes from the mocked connection.
    default_url = postgres.engine_string()
    assert default_url.startswith('postgres')
    assert 'test:password@localhost:5432/sitedb' in default_url

    # Overrides: host/port replaced, credentials and schema unchanged.
    overridden_url = postgres.engine_string(host='notlocalhost', port=1234)
    assert 'test:password@notlocalhost:1234/sitedb' in overridden_url


@patch('zenyatta.db.sql.get_airflow_connection', return_value=airflow_connection_mock())
def test_mysql_driver(conn_mock):
    """MySQL engine strings use the pymysql dialect and the mocked connection.

    With no arguments, host and port come from the mocked connection.
    Explicit ``host``/``port`` keyword arguments replace them while the
    credentials and schema are kept.
    """
    mysql_driver = SQL.create_sql_driver('mysql')

    # Defaults: everything comes from the mocked connection.
    default_url = mysql_driver.engine_string()
    assert default_url.startswith('mysql+pymysql')
    assert 'test:password@localhost:5432/sitedb' in default_url

    # Overrides: host/port replaced, credentials and schema unchanged.
    overridden_url = mysql_driver.engine_string(host='notlocalhost', port=1234)
    assert 'test:password@notlocalhost:1234/sitedb' in overridden_url
