from __future__ import print_function
import converters

from collections import namedtuple, defaultdict
from functools import partial
import logging
import psycopg2
import subprocess as sp
from StringIO import StringIO
import sys
from tempfile import NamedTemporaryFile
from traceback import format_exc
import yaml
import pytest
from osgeo import ogr, osr

import yatest.common

from yandex.maps.pgpool3 import PgPool, PoolConstants
from maps.pylibs.local_postgres import postgres_instance
from maps.wikimap.mapspro.libs.python import cpplogger
from maps.wikimap.mapspro.libs.python import revision
from maps.wikimap.mapspro.libs.python.validator import Validator, ValidatorConfig

# Branch id handed to RevisionsGateway and to the validator for the trunk branch
# (the revisionapi import below uses --branch=trunk).
TRUNK_BRANCH_ID = 0

# SQL script (source-tree relative) executed to (re)build the revision schema,
# and the name of the schema that is dropped/patched by this harness.
REVISION_SCHEMA_PATH = 'maps/wikimap/mapspro/libs/revision/sql/postgres_upgrade.sql'
REVISION_DB_SCHEMA = 'revision'

# Spatial references for WGS84 (EPSG:4326) and World Mercator (EPSG:3395) and
# the transformation between them, used to convert test AOIs.
SR4326 = osr.SpatialReference()
SR4326.ImportFromEPSG(4326)
SR3395 = osr.SpatialReference()
SR3395.ImportFromEPSG(3395)
# GDAL >= 3 follows the EPSG axis order (lat, lon for 4326) by default, which
# would silently swap the x/y of lon/lat WKT. Older GDAL has no such setting
# (and always uses the traditional x=lon, y=lat order), hence the hasattr guard.
if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
    SR4326.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    SR3395.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
FROM_4326_TO_3395 = osr.CoordinateTransformation(SR4326, SR3395)


def wgs84_wkt_to_mercator_wkb(wkt):
    """Reproject a WKT geometry from EPSG:4326 to EPSG:3395 and return it as WKB."""
    geom = ogr.CreateGeometryFromWkt(wkt)
    # Transform() reprojects the geometry object in place.
    geom.Transform(FROM_4326_TO_3395)
    wkb = geom.ExportToWkb()
    return wkb


# psycopg2 connections cached by connection string (see get_conn).
_conns = {}


def get_conn(params):
    """
    Return the cached psycopg2 connection for ``params.connection_string``.

    The connection is opened on first request and reused afterwards. The old
    ``_conns.setdefault(key, psycopg2.connect(key))`` evaluated ``connect``
    on every call, so each call opened (and leaked) a new connection even
    when a cached one was returned.
    """
    conn_string = params.connection_string
    conn = _conns.get(conn_string)
    if conn is None:
        conn = psycopg2.connect(conn_string)
        _conns[conn_string] = conn
    return conn


def init_db(conn_params):
    """Enable the postgis and hstore extensions in the database and commit."""
    connection = get_conn(conn_params)
    cur = connection.cursor()
    for statement in ('CREATE EXTENSION postgis', 'CREATE EXTENSION hstore'):
        cur.execute(statement)
    connection.commit()


@pytest.fixture(scope="module")
def postgres():
    """
    Module-scoped fixture: initialise cpplogger, start a local PostgreSQL
    instance, enable its extensions via init_db and yield its connection
    parameters for the lifetime of the module.
    """
    cpplogger.init_logger()
    with postgres_instance() as params:
        init_db(params)
        yield params


def create_pgpool(conn_params):
    """
    Build a PgPool for the local instance at localhost:<conn_params.port>.

    The same (host, port) pair is passed both as the main instance and as the
    only entry of the instance list. The 'pgpool3' logger is limited to ERROR.
    """
    pgpool_logger = logging.getLogger('pgpool3')
    pgpool_logger.setLevel(logging.ERROR)

    constants = PoolConstants(1, 1, 2, 2)
    constants.wait_for_availability_info = True

    instance = ('localhost', conn_params.port)
    return PgPool(instance, [instance], conn_params.connection_string,
                  constants, pgpool_logger)


def apply_patch(fname, conn):
    """
    Execute the SQL script ``fname`` on ``conn`` with the revision schema first
    on the search path, rebuild attributes_relations, then commit.
    """
    cur = conn.cursor()
    cur.execute('SET SEARCH_PATH=%s,public' % REVISION_DB_SCHEMA)
    with open(fname) as patch_file:
        patch_sql = patch_file.read()
        cur.execute(patch_sql)
    # Replace the contents of attributes_relations with the attributes rows
    # referenced by object_revision_relation.
    rebuild_sql = ("DELETE FROM attributes_relations;"
                   "INSERT INTO attributes_relations "
                   "   SELECT * FROM attributes "
                   "    WHERE id IN (SELECT attributes_id FROM object_revision_relation);")
    cur.execute(rebuild_sql)
    conn.commit()
    logging.info('Patched %s schema with %s', REVISION_DB_SCHEMA, fname)


def reset_revisions(conn_params):
    """
    Drop the revision schema (cascading) if it exists, then execute the
    revision schema SQL script (presumably recreating it) and commit.
    """
    connection = get_conn(conn_params)
    cur = connection.cursor()
    cur.execute('DROP SCHEMA IF EXISTS %s CASCADE' % REVISION_DB_SCHEMA)
    schema_sql_path = yatest.common.source_path(REVISION_SCHEMA_PATH)
    with open(schema_sql_path) as schema_file:
        schema_sql = schema_file.read()
        cur.execute(schema_sql)
    connection.commit()
    logging.info('YmapsDF schema %s reset', REVISION_DB_SCHEMA)


def pgpool_cfg(conn_params):
    """
    Render the pgpool XML config consumed by the revisionapi tool.

    ``conn_params`` must provide ``dbname``, ``port``, ``user`` and
    ``password``; they are used for both the write and the read host
    (localhost).
    """
    template = """<?xml version="1.0" encoding="utf-8"?>
<config>
    <common>
        <databases>
            <database id="long-read" name="{0.dbname}">
                <write host="localhost" port="{0.port}" user="{0.user}" pass="{0.password}"/>
                <read host="localhost" port="{0.port}" user="{0.user}" pass="{0.password}"/>
                <pools nearestDC="1" failPingPeriod="5" pingPeriod="30" timeout="5">
                    <revisionapi writePoolSize="2" writePoolOverflow="0" readPoolSize="2" readPoolOverflow="0"/>
                </pools>
            </database>
        </databases>
    </common>
</config>
"""
    return template.format(conn_params)

# Sink for the revisionapi tool's stdout; opened once at import time and shared.
DEVNULL = open('/dev/null', mode='w')


def json_to_revision(fname, conn_params):
    """
    Reset the revision schema and import the JSON file ``fname`` into trunk
    using the revisionapi tool (its stdout is discarded).

    Raises subprocess.CalledProcessError if the tool exits with non-zero status.
    """
    reset_revisions(conn_params)
    with NamedTemporaryFile() as cfg_file:
        cfg_file.write(pgpool_cfg(conn_params))
        # The tool reads the config by file name, so it must hit the disk first.
        cfg_file.flush()
        with open(fname) as json_file:
            revisionapi = yatest.common.binary_path(
                'maps/wikimap/mapspro/tools/revisionapi/revisionapi')
            command = [
                revisionapi,
                '--cmd=import',
                '--branch=trunk',
                '--user-id=1',
                '--start-from-json-id',
                '--cfg=%s' % cfg_file.name,
            ]
            sp.check_call(command, stdin=json_file, stdout=DEVNULL)


# Possible values of TestResult.status.
PASS, FAIL, EXCEPTION = 'PASS', 'FAIL', 'EXCEPTION'


class TestResult(object):
    """
    Outcome of a test case or a suite: ``status`` (PASS, FAIL or EXCEPTION)
    plus free-form ``info`` text. Truthy only when the status is PASS.
    """

    # Despite the Test* name this is not a test class; stop pytest from
    # trying to collect it (it would warn about the __init__ constructor).
    __test__ = False

    def __init__(self, status, info):
        self.status = status
        self.info = info

    def __nonzero__(self):
        return self.status == PASS

    # Python 3 name of __nonzero__; without it every instance would be truthy.
    __bool__ = __nonzero__

    def __repr__(self):
        return self.status


# Identity of a validator message: its description text and a frozenset of oids.
ExpectedMessage = namedtuple('ExpectedMessage', ['description', 'oids'])


def compare(got, expected):
    """
    Match reported messages against expected groups of alternatives.

    :param got: iterable of reported (hashable) messages.
    :param expected: iterable of non-empty groups; each group is a list of
        acceptable alternatives and consumes one reported occurrence of the
        first alternative that still has an unconsumed occurrence.
    :returns: ``(missing, excess)`` - distinct messages expected but not
        reported, and distinct messages reported but not expected.
    """
    got_counts = defaultdict(int)
    for message in got:
        got_counts[message] += 1
    for alternatives in expected:
        for message in alternatives:
            # Only a positive count is an unconsumed reported occurrence; a
            # negative one (already recorded as missing) must not be "used".
            if got_counts[message] > 0:
                got_counts[message] -= 1
                break
        else:
            # No alternative was reported: record the first one as missing.
            got_counts[alternatives[0]] -= 1

    return ([m for m, c in got_counts.items() if c < 0],
            [m for m, c in got_counts.items() if c > 0])


def perform_test(conn_params, validator, check, data_file, actions_list, expected_list, aoi_wkt):
    """
    Run one test case of ``check`` on the source data ``data_file``.

    The data is imported into a freshly reset revision schema, then every
    function in ``actions_list`` is applied to the database connection *in
    the order given*. Two kinds of actions exist:
    - patch - any SQL script that changes the database state;
    - converter - converts objects of one type into another.
    The validator is then run (restricted to the optional area ``aoi_wkt``,
    given as WGS84 WKT) and its messages are compared with ``expected_list``.

    Returns a TestResult: PASS with empty info, FAIL with a YAML listing of
    missing/unexpected messages, or EXCEPTION with the formatted traceback.
    """
    def print_messages(messages, stream=sys.stdout, indent=0):
        # Dump messages as YAML, indenting every line by ``indent`` spaces.
        res = StringIO()
        yaml.dump([{'description': m.description, 'oids': sorted(m.oids)}
                   for m in messages], res)
        # splitlines() drops the '\n' that print() would otherwise double.
        for line in res.getvalue().splitlines():
            print(' ' * indent + line, file=stream)

    try:
        json_to_revision(data_file, conn_params)

        for action in actions_list:
            action(get_conn(conn_params))

        # Create the pool anew every time to avoid the
        # "cached plan must not change result type" error.
        pgpool = create_pgpool(conn_params)
        head_commit_id = (revision.RevisionsGateway(pgpool, TRUNK_BRANCH_ID)
                          .head_commit_id())

        validator_args = [[check], pgpool, TRUNK_BRANCH_ID, head_commit_id]
        if aoi_wkt:
            validator_args.append(wgs84_wkt_to_mercator_wkb(aoi_wkt))  # aoi_wkb
            validator_args.append('./coverage-test')  # aoi_coverage_dir
            validator_args.append(0.0)  # aoi_buffer

        got = [ExpectedMessage(m.attributes.description, frozenset(m.oids))
               for m in validator.run(*validator_args).messages()]
        missing, excess = compare(got, expected_list)
    except Exception:
        return TestResult(EXCEPTION, format_exc().rstrip())

    if not missing and not excess:
        return TestResult(PASS, '')

    info = StringIO()
    if missing:
        print('The following items were expected but not reported: ', file=info)
        print_messages(missing, stream=info, indent=2)
    if excess:
        print('The following unexpected items were reported: ', file=info)
        print_messages(excess, stream=info, indent=2)
    return TestResult(FAIL, info.getvalue().rstrip())


# One parsed suite item: case id, data file, action callables, expected
# message groups and optional WGS84 WKT area of interest.
TestCase = namedtuple('TestCase', 'id data actions expected aoi_wkt')
# Not a test class despite the Test* name: keep pytest from collecting it
# (it would warn about the namedtuple's __new__ constructor).
TestCase.__test__ = False


def data_path(relpath):
    """Resolve ``relpath`` inside the validator-checks tests directory via yatest source_path."""
    tests_root = 'maps/wikimap/mapspro/services/tasks/validator-checks/tests/'
    return yatest.common.source_path(tests_root + relpath)


class InvalidActionError(Exception):
    """Raised when a test case's ``actions`` list holds a malformed or unknown action."""


def load_actions(actions_node):
    """
    Convert the ``actions`` node of a test case into a list of callables.

    Each item must be a mapping with exactly one key:
    - ``patch: <path>`` - apply_patch with the SQL file resolved by data_path;
    - ``convert: <arg>`` - converters.convert bound to ``<arg>``.
    perform_test calls every returned callable with a database connection.

    :raises InvalidActionError: on a malformed item or an unknown action.
    """
    if actions_node is None:
        return []

    result = []
    for child in actions_node:
        if not isinstance(child, dict):
            # Previously a non-mapping item crashed with AttributeError/TypeError.
            raise InvalidActionError(
                'Each item should be a mapping with one action, found: %r' % (child,))
        if len(child) != 1:
            raise InvalidActionError(
                'Each item should contain one action, %s found: %s'
                % (len(child), list(child)))
        # Exactly one entry; works with both Python 2 and 3 dict APIs.
        (name, arg), = child.items()
        if name == 'patch':
            result.append(partial(apply_patch, data_path(arg)))
        elif name == 'convert':
            result.append(partial(converters.convert, arg))
        else:
            raise InvalidActionError('Unknown action: %s' % name)

    return result


def load_messages(expected_node):
    """
    Convert the ``expected`` node of a test case into a list of groups.

    Each group is a list of ExpectedMessage alternatives; compare() treats a
    group as satisfied when any one of them is reported. Supported items:
    - ``{include: <path>}`` - items loaded from a YAML file resolved by data_path;
    - ``{any: [...]}`` - one group made of all messages of the nested items;
    - ``{description: ..., oids: [...]}`` - a single expected message.
    A falsy node (None, empty list) yields no groups.
    """
    if not expected_node:
        return []

    result = []
    for child in expected_node:
        if 'include' in child:
            with open(data_path(child['include'])) as include_file:
                # safe_load: plain data only; yaml.load without a Loader is
                # unsafe and raises TypeError on PyYAML >= 6.
                result += load_messages(yaml.safe_load(include_file))
        elif 'any' in child:
            alternatives = []
            for group in load_messages(child['any']):
                alternatives.extend(group)
            result.append(alternatives)
        else:
            result.append([ExpectedMessage(child['description'],
                                           frozenset(child['oids']))])

    return result


def load_suite(stream):
    """
    Parse a YAML test suite from ``stream`` into a list of TestCase.

    Every item needs ``data`` and ``expected``; ``case`` (the id), ``actions``
    and ``aoi`` (WGS84 WKT) are optional. An empty document yields no cases.
    """
    # safe_load: plain data only; yaml.load without a Loader is unsafe and
    # raises TypeError on PyYAML >= 6. An empty file loads as None.
    cases = yaml.safe_load(stream) or []
    return [TestCase(case.get('case'), case['data'],
                     load_actions(case.get('actions')),
                     load_messages(case['expected']),
                     case.get('aoi'))
            for case in cases]


def run_suite(conn_params, validator, check, suite_fname):
    """
    Run every test case from the YAML suite file ``suite_fname``.

    Returns a TestResult:
    - EXCEPTION if the suite cannot be loaded, or as soon as a case raises
      (remaining cases are skipped; info of earlier failed cases is kept);
    - FAIL if at least one case failed, with the info of all failed cases;
    - PASS otherwise.
    """
    def format_case_info(case_id, case_info):
        # Put the case id (if any) in a header and indent the info under it.
        if case_id is None:
            return case_info
        # '%s' rather than '+': YAML may produce a numeric case id.
        return ('In test case %s:\n' % (case_id,)
                + '  ' + case_info.replace('\n', '\n  '))

    try:
        with open(suite_fname) as suite_file:
            suite = load_suite(suite_file)
    except Exception:
        return TestResult(EXCEPTION, 'While loading ' + suite_fname + '\n'
                          + format_exc().rstrip())

    status = PASS
    info = StringIO()
    for case in suite:
        result = perform_test(
            conn_params, validator, check, data_path(case.data),
            case.actions, case.expected, case.aoi_wkt)

        if result.status == EXCEPTION:
            print(format_case_info(case.id, result.info), file=info)
            return TestResult(EXCEPTION, info.getvalue().rstrip())
        elif result.status == FAIL:
            status = FAIL
            print(format_case_info(case.id, result.info), file=info)

    return TestResult(status, info.getvalue().rstrip())


@pytest.fixture(scope="module")
def validator_config():
    """Module-scoped ValidatorConfig built from the editor.xml config in the source tree."""
    return ValidatorConfig(
        yatest.common.source_path('maps/wikimap/mapspro/cfg/editor/editor.xml'))


@pytest.fixture(scope="module")
def validator(validator_config):
    """Module-scoped Validator for ``validator_config`` with its modules initialised."""
    instance = Validator(validator_config)
    instance.init_modules()
    return instance


def test_suite(postgres, validator, check):
    """
    Run the YAML suite ``suites/<check>.yaml`` for one validator check.

    ``check`` is provided outside the visible code (fixture or
    parametrization) - presumably the check name.
    """
    suite = data_path('suites/' + check + '.yaml')
    result = run_suite(postgres, validator, check, suite)
    if result.status != PASS:
        logging.error(result.info)
    # Compare the status explicitly (not via truthiness) and show the details
    # in the pytest failure report.
    assert result.status == PASS, result.info
