# coding: utf8

from collections import defaultdict
import json
import logging
import re

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import sqlalchemy.sql.expression as sqlexpr
import geoalchemy2 as ga
import geoalchemy2.shape

import dateutil.parser
from shapely import wkb
from shapely.geometry import mapping
from flask import Blueprint, request, abort

from yandex.maps.wiki import utils, db, fastcgihelpers as fh
from yandex.maps.wiki import config
from yandex.maps.wiki.pgpool3 import get_pgpool, init_pgpool
from yandex.maps.wiki.utils import require, string_to_bool
from yandex.maps.wiki.utils import mercator_to_geodetic, mercator_distance_ratio

from yandex.maps.wiki.tasks import EM, register_task_type, grinder
from yandex.maps.wiki.tasks.models import Base, Task
from yandex.maps.wiki.tasks.stats import Statistics

from maps.wikimap.mapspro.libs.python import acl, revision
from maps.wikimap.mapspro.libs.python.validator import Validator, ValidatorConfig, Severity, RegionType, ResultsGateway, IssueCreator

import six
if six.PY3:
    # Python 3 folded ``long`` into ``int``; alias it so the ``long(...)``
    # conversions in this module work on both major versions.
    long = int

# Registered task-type name. It is also the Flask blueprint name and the base
# of the grinder task type ('validation' or 'validation.heavy').
TASK_NAME = 'validation'
# ACL permission path required by ``check_acl``.
ACL_PATH = 'mpro/tasks/validator'

# Area threshold (square mercator meters) above which a validation is
# submitted as a heavy task. It is compared against the AOI area scaled by
# the squared mercator distance ratio at the AOI centroid.
HEAVY_VALIDATION_MIN_AREA = 2.0e12


def check_acl(uid):
    """Ensure user ``uid`` holds the ``ACL_PATH`` permission in the core DB.

    Fails through ``require`` with a ``ServiceException`` whose status is
    ``ERR_FORBIDDEN`` when the permission is not granted.
    """
    granted = acl.is_permission_granted(get_pgpool(db.CORE_DB), ACL_PATH, uid)
    require(granted, fh.ServiceException('forbidden', status='ERR_FORBIDDEN'))


@utils.threadsafe_memoize
def get_validator_config():
    """Build the ``ValidatorConfig`` from the editor config path in the XML config."""
    xml_config = config.get_config().xml
    return ValidatorConfig(xml_config.get("services/editor/config"))


@utils.threadsafe_memoize
def get_validator():
    """Return a ``Validator`` built from the shared config, with modules initialized."""
    instance = Validator(get_validator_config())
    instance.init_modules()
    return instance


@utils.threadsafe_memoize
def get_all_checks():
    """Return the ids of every check provided by the loaded validator modules.

    The result is a flat list in module order. It is built with a nested
    comprehension rather than ``sum(lists, [])``. ``sum`` copies the
    accumulated list on every addition, which is quadratic, and it only works
    when every ``check_ids`` is a ``list``.
    """
    return [check_id
            for module in get_validator().modules()
            for check_id in module.check_ids]


@utils.threadsafe_memoize
def get_pgpool_validation_tasks():
    """Return the validation-DB connection pool, initializing it first.

    The memoize decorator is expected to cache the result, so
    ``init_pgpool`` presumably runs only on the first call.
    """
    validation_db_config = config.get_config().databases[db.VALIDATION_DB]
    init_pgpool(validation_db_config, 'tasks')
    return get_pgpool(db.VALIDATION_DB)


@utils.threadsafe_memoize
def get_issue_creator():
    """Return an ``IssueCreator`` for the Startrek API base URL from the XML config."""
    api_base_url = config.get_config().xml.get_attr("common/st", "api-base-url")
    logging.info("Startrek URL %s", api_base_url)
    return IssueCreator(api_base_url)


def login_by_uid(uid):
    """Resolve user ``uid`` to a login via ``acl.login_by_uid`` on the core DB pool."""
    core_pool = get_pgpool(db.CORE_DB)
    return acl.login_by_uid(core_pool, uid)


def create_results_gateway(task_id):
    """Return a ``ResultsGateway`` over the validation-DB pool for ``task_id``."""
    pool = get_pgpool_validation_tasks()
    return ResultsGateway(pool, task_id)


def create_grinder_gateway():
    """Return a ``GrinderGateway`` pointed at the configured grinder host."""
    grinder_host = config.get_config().grinder_params.host
    return grinder.GrinderGateway(grinder_host)


def aoi_geom(aoi_ids, branch_id, commit_id):
    """Return the geometry of a single AOI object, or ``None``.

    Only a request with exactly one object id is resolved. With zero or
    several ids there is no single AOI polygon, so ``None`` is returned.
    ``None`` is also returned when the object lacks the ``cat:aoi``
    attribute. The returned value is ``aoi.geom`` as stored in the revision
    (presumably WKB; ``Validation.launch`` feeds it to ``wkb.loads``).

    :param aoi_ids: sequence of object ids.
    :param branch_id: branch to read the object revision from.
    :param commit_id: commit at which to read the object revision.
    :raises fh.ServiceException: ``ERR_MISSING_OBJECT`` if the object
        revision is not found.
    """
    # Reject empty input too: indexing [0] on an empty sequence would raise.
    if len(aoi_ids) != 1:
        return None
    aoi_id = aoi_ids[0]
    rgateway = revision.create_gateway(db.CORE_DB, branch_id)
    aoi = rgateway.object_revision(aoi_id, commit_id)
    require(aoi,
            fh.ServiceException(
                'Object id: %s not found' % aoi_id, status='ERR_MISSING_OBJECT'))
    if 'cat:aoi' not in aoi.attrs:
        return None
    return aoi.geom


def message_ET(message_datum):
    """Render one validation message datum as an ``EM.message`` element.

    The element carries severity, check, description, region type and the
    active/viewed flags. It optionally has these children:
    - an ``exclusion`` (creation time in ISO format and its author);
    - a ``point`` with the geometry centroid as GeoJSON;
    - a ``geometry`` child for non-point geometries.
    The ``oids`` child lists at most the first 10 object ids. Geometry
    decoding errors are logged and the geometry children are skipped.
    """
    message = message_datum.message()
    attrs = message.attributes

    element = EM.message(id=message_datum.id,
                         severity=attrs.severity.name.lower(),
                         check=attrs.check_id,
                         description=attrs.description,
                         region_type=attrs.region_type.name.lower(),
                         active=message_datum.is_active,
                         viewed=message_datum.is_viewed)

    exclusion_info = message_datum.exclusion_info
    if exclusion_info:
        parsed_created_at = dateutil.parser.parse(exclusion_info.created_at)
        element.append(EM.exclusion(created=parsed_created_at.isoformat(),
                                    created_by=exclusion_info.created_by))

    if message.geom_wkb:
        try:
            shape_obj = wkb.loads(message.geom_wkb)
            centroid_geojson = utils.mercator_wkb_to_geojson(shape_obj.centroid.wkb)
            element.append(EM.point(json.dumps(centroid_geojson)))

            # Points are fully described by their centroid; other shapes
            # additionally carry their full geometry.
            if shape_obj.geom_type != 'Point':
                full_geojson = utils.mercator_wkb_to_geojson(shape_obj.wkb)
                element.append(EM.geometry(json.dumps(full_geojson)))
        except Exception as e:
            logging.exception(e)

    oid_elements = [EM.oid(oid) for oid in message.oids[:10]]
    element.append(EM.oids(*oid_elements))
    return element


class Preset(Base):
    """A named, reusable set of validation check ids (``validation.preset``).

    The queries in this module show a user their own presets plus every
    preset with ``is_public`` set.
    """
    __tablename__ = 'preset'
    __table_args__ = {'schema': 'validation'}

    id = sa.Column(sa.BigInteger, primary_key=True)
    name = sa.Column(sa.String)
    is_public = sa.Column(sa.Boolean, default=False)
    created_at = sa.Column(sa.DateTime)
    # uid of the user who created the preset.
    created_by = sa.Column(sa.BigInteger)
    checks = sa.Column(postgresql.ARRAY(sa.String))

    @sa.orm.validates('name')
    def validate_name(self, key, name):
        """Strip surrounding whitespace and reject an empty or missing name.

        :raises fh.ServiceException: ``ERR_BAD_REQUEST`` for an empty name.
        """
        # Treat None like an empty name, so the caller gets a clean
        # ERR_BAD_REQUEST instead of an AttributeError on .strip().
        name = (name or '').strip()
        require(name,
                fh.ServiceException('Empty name', status='ERR_BAD_REQUEST'))
        return name

    @sa.orm.validates('checks')
    def validate_checks(self, key, checks):
        """Reject an empty check list or any check id unknown to the validator.

        :raises fh.ServiceException: ``ERR_BAD_REQUEST`` for an empty list,
            ``ERR_WRONG_CHECK_NAME`` for an unknown check id.
        """
        require(checks,
                fh.ServiceException('Empty checks', status='ERR_BAD_REQUEST'))
        # Use a set so each membership test is O(1) instead of a list scan.
        known_checks = set(get_all_checks())
        for check in checks:
            require(check in known_checks,
                    fh.ServiceException(
                        "Wrong check name '%s'" % check,
                        status='ERR_WRONG_CHECK_NAME'))
        return checks

    def get_ET_brief(self):
        """Return a short preset element: the name as text, plus id and creator."""
        return EM.validation_preset(
            self.name,
            id=self.id,
            created_by=self.created_by)

    def get_ET_full(self):
        """Return the full preset element with name, checks and metadata."""
        return EM.validation_preset(
            EM.name(self.name),
            EM.checks(*[EM.check(check) for check in self.checks]),
            id=self.id,
            created_by=self.created_by,
            created=self.created_at,
            public=self.is_public)


@register_task_type(name=TASK_NAME)
class Validation:
    """Task-type handler registered under ``TASK_NAME``.

    It does three things:
    - describes capabilities (``capabilities_ET``);
    - builds ``ValidationTask`` rows from request parameters (``create``);
    - submits them to grinder (``launch``).
    """

    @staticmethod
    def capabilities_ET():
        """Describe the presets visible to the requester and all available checks.

        Visible presets are the requester's own plus all public ones. Checks
        are grouped per validator module under a root check element with id
        ``all``.
        """
        uid = int(request.values.get('uid', 0))
        check_acl(uid)
        with db.get_write_session(db.VALIDATION_DB) as session:
            query = session.query(Preset).filter(sa.or_(
                Preset.created_by == uid,
                Preset.is_public == sqlexpr.true()))
            presets = EM.validation_presets(
                *[p.get_ET_brief() for p in query])

        checks = EM.check(id='all')
        for module in get_validator().modules():
            checks.append(
                EM.check(*[EM.check(id=c) for c in module.check_ids], id=module.name))

        return EM.validation_task_type(presets, checks)

    @staticmethod
    def create(uid, request):
        """Build a new ``ValidationTask`` from request parameters.

        Exactly one check source must be given:
        - ``checks``: comma-separated check ids;
        - ``preset``: a preset id;
        - ``task-id``: re-run the checks that had fatal errors in an earlier
          task, reusing its area and branch.

        At most one area source may be given:
        - ``geometry``: a GeoJSON polygon;
        - ``aoi``: comma-separated object ids;
        - ``region``: a region id;
        - ``task-id``.

        Optional params:
        - ``aoi-buffer``: not allowed together with ``preset``;
        - ``branch``: a type name or a numeric id;
        - ``only-changed-objects``;
        - ``commit``: defaults to the head commit of the branch.

        :raises fh.ServiceException: on invalid parameter combinations or values.
        """
        check_acl(uid)
        task = ValidationTask()
        task.on_create(uid, request)

        values = request.values
        # 'task-id' provides both the checks and the area, so it counts for both groups.
        checks_params_count = sum(
            1 for param in ('checks', 'preset', 'task-id') if param in values)
        geom_params_count = sum(
            1 for param in ('aoi', 'region', 'geometry', 'task-id') if param in values)
        require(checks_params_count == 1,
                fh.ServiceException(
                    "Exactly one of the params 'checks', 'preset' and 'task-id' must be set",
                    status='ERR_BAD_REQUEST'))
        require(geom_params_count <= 1,
                fh.ServiceException(
                    "At most one of the params 'aoi', 'geometry', 'region' and 'task-id' can be set",
                    status='ERR_BAD_REQUEST'))
        require(('preset' not in values) or ('aoi-buffer' not in values),
                fh.ServiceException(
                    "Params 'preset' and 'aoi-buffer' can't be set simultaneously",
                    status='ERR_BAD_REQUEST'))

        preset_id = values.get('preset', 0)
        if preset_id:
            with db.get_write_session(db.VALIDATION_DB) as session:
                preset = session.query(Preset).get(preset_id)
                require(preset,
                        fh.ServiceException(
                            'Preset id %s is not found' % preset_id,
                            status='ERR_MISSING_PRESET'))
                require(preset.created_by == uid or preset.is_public,
                        fh.ServiceException(
                            'User %s is not allowed to use preset %s' % (uid, preset_id),
                            status='ERR_FORBIDDEN'))
                task.checks = preset.checks

                # A preset may encode its buffer in its name, e.g. "... buffer=12.5".
                match = re.search(r'buffer=(\d+(\.\d+)?)', preset.name)
                if match:
                    task.aoi_buffer = float(match.group(1))
        elif 'checks' in values:
            task.checks = values['checks'].split(',')

        if 'geometry' in values:
            wkb_str = utils.geojson_to_mercator_wkb(str(values['geometry']))
            shape = wkb.loads(wkb_str)
            require(shape.geom_type == 'Polygon' and shape.is_valid,
                    fh.ServiceException('Invalid geometry', status='ERR_TOPO_INVALID_GEOMETRY'))
            task.aoi_geom = ga.shape.from_shape(shape, srid=3395)
        elif 'aoi' in values:
            task.aoi_ids = [long(x) for x in values['aoi'].split(',')]
        elif 'region' in values:
            task.region_id = long(values['region'])
        if 'aoi-buffer' in values:
            task.aoi_buffer = float(values['aoi-buffer'])

        branch = values.get('branch', '0')
        if branch in ['trunk', 'approved', 'stable']:
            branch_mgr = revision.BranchManager(get_pgpool(db.CORE_DB))
            branch_id = branch_mgr.branch_id_by_type(branch.encode('utf-8'))
            require(branch_id is not None,
                    fh.ServiceException("Branch of requested type does not exist",
                                        status='ERR_BAD_REQUEST'))
        else:
            branch_id = long(branch)
        task.branch_id = branch_id

        old_task_id = values.get('task-id', 0)
        if old_task_id:
            with db.get_read_session(db.CORE_DB) as session:
                old_id = long(old_task_id)
                checks = create_results_gateway(old_id).check_ids_with_fatal_errors()
                require(checks, fh.ServiceException(
                    'Task %s has no fatal errors' % old_task_id,
                    status='ERR_NO_FATAL_ERRORS'))
                old_task = session.query(ValidationTask).get(old_id)
                require(old_task, fh.ServiceException(
                    'Task id %s is not found' % old_task_id,
                    status='ERR_BAD_REQUEST'))
                task.checks = checks
                task.aoi_geom = old_task.aoi_geom
                task.aoi_ids = old_task.aoi_ids
                task.region_id = old_task.region_id
                task.aoi_buffer = old_task.aoi_buffer
                task.branch_id = old_task.branch_id

        # Checked only after the 'task-id' copy above, which may replace
        # branch_id: the restriction must apply to the branch that will
        # actually be validated.
        task.only_changed_objects = string_to_bool(values.get('only-changed-objects', '0'))
        if task.only_changed_objects:
            revision_gateway = revision.RevisionsGateway(get_pgpool(db.CORE_DB), task.branch_id)
            require(revision_gateway.branch_type() in ['stable', 'archive'],
                    fh.ServiceException("Validation by changed objects is not allowed in this branch",
                                        status='ERR_BAD_REQUEST'))

        commit_id = long(values.get('commit', 0))
        if commit_id == 0:
            rgateway = revision.create_gateway(db.CORE_DB, task.branch_id)
            commit_id = rgateway.head_commit_id()
        task.commit_id = commit_id

        return task

    @staticmethod
    def launch(session, task_id, request):
        """Prepare task ``task_id`` and submit it to grinder.

        Resolves the AOI shape, recording a single AOI object's geometry on
        ``task.aoi_geom``. It then decides whether the task is heavy and
        records that on ``task.is_heavy``. Returns the result of
        ``GrinderGateway.submit``.
        """
        task = session.query(ValidationTask).get(task_id)
        if task.aoi_geom is not None:
            shape = ga.shape.to_shape(task.aoi_geom)
        elif task.aoi_ids:
            geom = aoi_geom(task.aoi_ids, task.branch_id, task.commit_id)
            shape = wkb.loads(geom) if geom else None
            task.aoi_geom = ga.shape.from_shape(shape, srid=3395) if shape else None
        else:
            shape = None

        args = {
            'uid': task.created_by,
            'taskId': task.id,
            'checks': task.checks,
            'branchId': task.branch_id,
            'commitId': task.commit_id
        }

        if task.parent_id:
            args['parentTaskId'] = task.parent_id

        # Without a resolved shape (region, several AOI ids, or no area at
        # all) the task is always treated as heavy.
        is_heavy = True
        if shape:
            centroid = shape.centroid
            (_, lat) = mercator_to_geodetic(centroid.x, centroid.y)
            # Scale the mercator area by the squared distance ratio at the
            # centroid latitude before comparing with the threshold.
            ratio = mercator_distance_ratio(lat)
            is_heavy = shape.area * ratio * ratio > HEAVY_VALIDATION_MIN_AREA
            args['aoiGeom'] = mapping(shape)
        if task.aoi_ids:
            args['aoiIds'] = task.aoi_ids
        if task.region_id:
            args['regionId'] = task.region_id
        if task.aoi_buffer:
            args['aoiBuffer'] = task.aoi_buffer

        if task.only_changed_objects:
            args['onlyChangedObjects'] = task.only_changed_objects

        task.is_heavy = is_heavy
        args['type'] = TASK_NAME + ('.heavy' if is_heavy else '')

        gateway = create_grinder_gateway()
        return gateway.submit(args)

    flask_blueprint = Blueprint(TASK_NAME, __name__)


class ValidationTask(Task):
    """Validation-specific part of a task row (``service.validation_task``).

    Joined to ``service.task`` by ``id``. It stores:
    - the checks to run;
    - the branch and commit to validate;
    - the area of interest: an explicit polygon (``aoi_geom``), AOI object
      ids (``aoi_ids``) or a region id (``region_id``), plus ``aoi_buffer``.
    """
    __tablename__ = 'validation_task'
    __table_args__ = {'schema': 'service'}
    __mapper_args__ = {'polymorphic_identity': 'validation'}

    id = sa.Column(sa.BigInteger, sa.ForeignKey('service.task.id'), primary_key=True)
    checks = sa.Column(postgresql.ARRAY(sa.String))
    branch_id = sa.Column(sa.BigInteger)
    commit_id = sa.Column(sa.BigInteger)
    # Polygon in SRID 3395. Set from the 'geometry' request param, or at
    # launch from a single AOI object's geometry.
    aoi_geom = sa.Column(ga.Geometry('POLYGON', srid=3395))
    aoi_ids = sa.Column(postgresql.ARRAY(sa.BigInteger))
    region_id = sa.Column(sa.BigInteger)
    aoi_buffer = sa.Column(sa.Float, default=0.0)
    # Decided at launch from the AOI area; selects the '.heavy' grinder task type.
    is_heavy = sa.Column(sa.Boolean)
    only_changed_objects = sa.Column(sa.Boolean)

    def context_aoi(self):
        """Build the ``aoi`` element: AOI/region objects plus the buffer if any area is set."""
        aoi = EM.aoi()
        if self.aoi_ids:
            aoi.extend([EM.object(id=aoi_id, category_id='aoi') for aoi_id in self.aoi_ids])
        if self.region_id:
            aoi.append(EM.object(id=self.region_id, category_id='region'))
        if self.aoi_ids or self.region_id or (self.aoi_geom is not None):
            aoi.append(EM.buffer(self.aoi_buffer))
        return aoi

    def context_ET_brief(self, *args, **kwargs):
        """Return the brief context: branch, changed-objects flag and AOI."""
        ret = EM.validation_context(
            EM.branch(self.branch_id),
            EM.only_changed_objects(self.only_changed_objects))
        ret.append(self.context_aoi())
        return ret

    def context_ET_full(self, *args, **kwargs):
        """Return the full context: checks, branch, changed-objects flag and AOI.

        The full context also includes the AOI geometry (as GeoJSON) when set.
        """
        aoi = self.context_aoi()
        if self.aoi_geom is not None:
            geometry_json = json.dumps(utils.geoalchemy_to_geojson(self.aoi_geom))
            aoi.append(EM.geometry(geometry_json))

        return EM.validation_context(
            # Tolerate a task whose checks were never set.
            EM.checks(*[EM.check(check)
                        for check in (self.checks or [])]),
            EM.branch(self.branch_id),
            EM.only_changed_objects(self.only_changed_objects),
            aoi)

    def result_ET_brief(self):
        """Return the total message count, or ``None`` if it cannot be read.

        Failures are logged rather than propagated, presumably so one task's
        unavailable results do not break a listing.
        """
        try:
            results_gateway = create_results_gateway(self.id)
            return EM.validation_result(
                EM.messages(total_count=results_gateway.message_count()))
        except Exception as e:
            # Exception, not BaseException: KeyboardInterrupt and SystemExit
            # must still propagate.
            logging.exception(e)
            return None

    def result_ET_full(self, page=1, per_page=10, *args, **kwargs):
        """Build the full result: statistics breakdowns plus one page of messages.

        Recognized ``kwargs``:
        - ``token``, ``branch`` and ``uid``: passed to the messages query;
        - filters ``severity``, ``check``, ``description`` and ``region_type``.
        An unknown severity or region type aborts with HTTP 400.
        """
        results_gateway = create_results_gateway(self.id)
        token = str(kwargs.get('token', ''))
        branch_id = long(kwargs.get('branch', 0))
        uid = long(kwargs.get('uid', 0))
        # Built once for the membership tests in stats_by_module; unset
        # checks are treated as an empty set.
        task_checks = set(self.checks or [])

        def severity_from_str(severity_str):
            if severity_str is None:
                return None

            severity_str = severity_str.upper()
            if severity_str not in Severity.names:
                abort(400, 'bad severity: %s' % severity_str)
            return Severity.names[severity_str]

        def region_type_from_str(region_type_str):
            if region_type_str is None:
                return None

            region_type_str = region_type_str.upper()
            if region_type_str not in RegionType.names:
                abort(400, 'unknown region value: %s' % region_type_str)
            return RegionType.names[region_type_str]

        def get_filters(kwargs):
            ret = {}
            if 'severity' in kwargs:
                ret['severity'] = severity_from_str(kwargs['severity'])
            if 'check' in kwargs:
                ret['check_id'] = str(kwargs['check'])
            if 'description' in kwargs:
                ret['description'] = str(kwargs['description'])
            if 'region_type' in kwargs:
                ret['region_type'] = region_type_from_str(kwargs['region_type'])

            return ret

        def stats_by_region_type():
            counts = {rt.lower(): 0 for rt in RegionType.names}
            for attrs, count in results_gateway.statistics():
                counts[attrs.region_type.name.lower()] += count

            return [Statistics('region-type', id=rt, count=c)
                    for rt, c in counts.items()]

        def stats_by_severity(region_type):
            counts = {s.lower(): 0 for s in Severity.names}
            for attrs, count in results_gateway.statistics(region_type=region_type):
                counts[attrs.severity.name.lower()] += count

            return [Statistics('severity', id=s, count=c)
                    for s, c in counts.items()]

        def stats_by_module(region_type, severity):
            stats_by_check = defaultdict(lambda: defaultdict(int))
            for attrs, count in results_gateway.statistics(region_type=region_type, severity=severity):
                stats_by_check[attrs.check_id][attrs.description] += count

            ret = [
                Statistics('check', id='base', children=[
                    Statistics(
                        'description', id=d, count=c)
                        for d, c in stats_by_check['base'].items()])]

            for module in get_validator().modules():
                children = [
                    Statistics(
                        'check', id=check, children=[
                            Statistics('description', id=d, count=c)
                            for d, c in stats_by_check[check].items()])
                    for check in module.check_ids if check in task_checks]
                if children:
                    ret.append(Statistics('check', id=module.name, children=children))

            return ret

        def statistics_ET(stats_by_region_type, stats_by_severity, stats_by_module):
            return EM.statistics(
                EM.region_types(*[item.get_ET() for item in stats_by_region_type]),
                EM.severities(*[item.get_ET() for item in stats_by_severity]),
                EM.checks(*[item.get_ET() for item in stats_by_module]))

        def messages_ET(page, per_page, **filters):
            total_count = results_gateway.message_count(**filters)

            # Coerce both paging values the same way before doing arithmetic.
            page = int(page)
            per_page = int(per_page)
            page = fh.correct_page(page, per_page, total_count)
            offset = (page - 1) * per_page

            message_data = results_gateway.messages(
                get_pgpool(db.CORE_DB), token, branch_id,
                offset, per_page, uid, **filters)

            messages_el = EM.messages(page=page,
                                      per_page=per_page,
                                      total_count=total_count,
                                      *[message_ET(datum)
                                        for datum in message_data])

            message_geoms = []
            for datum in message_data:
                geom_wkb = datum.message().geom_wkb
                if not geom_wkb:
                    continue
                try:
                    message_geoms.append(wkb.loads(geom_wkb))
                except Exception as e:
                    # Same best-effort policy as message_ET: one malformed
                    # geometry must not fail the whole page.
                    logging.exception(e)

            if message_geoms:
                bbox_obj = [utils.mercator_wkb_to_geojson(p.wkb)['coordinates']
                            for p in utils.bbox(message_geoms)]
                bbox_obj = bbox_obj[0] + bbox_obj[1]
                messages_el.append(EM.bbox(json.dumps(bbox_obj)))

            return messages_el

        severity = severity_from_str(kwargs.get('severity'))
        region_type = region_type_from_str(kwargs.get('region_type'))

        return EM.validation_result(
            statistics_ET(
                stats_by_region_type(),
                stats_by_severity(region_type),
                stats_by_module(region_type, severity)),
            messages_ET(page, per_page, **get_filters(kwargs)))


@Validation.flask_blueprint.route('/presets', methods=['GET'])
@db.write_session(db.VALIDATION_DB)
def get_presets(session):
    """List validation presets visible to the requester, one page at a time.

    Request params: ``uid`` (requester), ``page``, ``per-page`` and an
    optional ``created-by``.
    - Without ``created-by``: the requester's own presets plus all public
      ones are listed.
    - With ``created-by``: only that user's presets are listed, and only
      public ones unless ``created-by`` is the requester.
    """
    uid = int(request.values.get('uid', 0))
    check_acl(uid)

    page = int(request.values.get('page', 1))
    per_page = int(request.values.get('per-page', 10))

    query = session.query(Preset)
    if 'created-by' in request.values:
        created_by = int(request.values['created-by'])
        query = query.filter(Preset.created_by == created_by)
        if created_by != uid:
            query = query.filter(Preset.is_public == sqlexpr.true())
    else:
        query = query.filter(sa.or_(
            Preset.created_by == uid,
            Preset.is_public == sqlexpr.true()))

    # Count once so the page correction and the reported total agree.
    total_count = query.count()
    page = fh.correct_page(page, per_page, total_count)
    offset = (page - 1) * per_page

    # Without ORDER BY, PostgreSQL does not guarantee a stable row order, so
    # consecutive LIMIT/OFFSET pages could overlap or skip presets.
    presets = query.order_by(Preset.id)[offset:offset + per_page]

    return fh.xml_response(
        EM.tasks(
            EM.response_validation_presets(
                EM.validation_presets(
                    page=page,
                    per_page=per_page,
                    total_count=total_count,
                    *[p.get_ET_full() for p in presets]))))


@Validation.flask_blueprint.route('/presets', methods=['POST'])
@db.write_session(db.VALIDATION_DB)
def create_preset(session):
    """Create a preset owned by the requester and return it.

    Request params: ``uid``, ``name``, ``public`` and ``checks``
    (comma-separated). ``Preset``'s validators check the name and the
    checks on assignment.

    :raises fh.ServiceException: ``ERR_DUPLICATE_PRESET_NAME`` when the
        commit fails with an integrity error.
    """
    uid = int(request.values['uid'])
    check_acl(uid)

    preset = Preset()
    preset.name = request.values['name']
    preset.is_public = string_to_bool(request.values['public'])
    preset.created_by = uid
    preset.created_at = utils.utcnow()
    preset.checks = request.values['checks'].split(',')

    try:
        session.add(preset)
        session.commit()
    except sa.exc.IntegrityError as e:
        logging.exception(e)
        # Build the error while preset.name still holds the submitted value,
        # then roll back: a failed flush leaves the session unusable until
        # rollback() is called.
        error = fh.ServiceException(
            'Duplicate preset name %s' % preset.name,
            status='ERR_DUPLICATE_PRESET_NAME')
        session.rollback()
        raise error

    return fh.xml_response(
        EM.tasks(
            EM.response_save_validation_preset(
                preset.get_ET_full())))


@Validation.flask_blueprint.route('/presets/<preset_id>', methods=['PUT'])
@db.write_session(db.VALIDATION_DB)
def change_preset(session, preset_id):
    """Update the name, checks and public flag of preset ``preset_id``.

    Both the owner and any user may edit a public preset. Only the owner may
    change the public flag.

    :raises fh.ServiceException:
        - ``ERR_MISSING_PRESET`` if the preset does not exist;
        - ``ERR_FORBIDDEN`` if the change is not permitted;
        - ``ERR_DUPLICATE_PRESET_NAME`` if the commit fails with an
          integrity error.
    """
    uid = int(request.values['uid'])
    check_acl(uid)

    preset = session.query(Preset).get(preset_id)
    require(preset,
            fh.ServiceException(
                'Preset id %s is not found' % preset_id,
                status='ERR_MISSING_PRESET'))
    require(preset.created_by == uid or preset.is_public,
            fh.ServiceException(
                'User %s is not allowed to change preset %s' % (uid, preset_id),
                status='ERR_FORBIDDEN'))

    new_is_public = string_to_bool(request.values['public'])
    require(preset.created_by == uid or preset.is_public == new_is_public,
            fh.ServiceException(
                'User %s is not allowed to change public state of preset %s' % (uid, preset_id),
                status='ERR_FORBIDDEN'))

    preset.name = request.values['name']
    preset.checks = request.values['checks'].split(',')
    preset.is_public = new_is_public

    try:
        session.commit()
    except sa.exc.IntegrityError as e:
        logging.exception(e)
        # Build the error before rollback: rollback expires the persistent
        # instance, and reading preset.name afterwards would reload the old
        # name from the database.
        error = fh.ServiceException(
            'Duplicate preset name %s %s' % (preset_id, preset.name),
            status='ERR_DUPLICATE_PRESET_NAME')
        session.rollback()
        raise error

    return fh.xml_response(
        EM.tasks(
            EM.response_save_validation_preset(
                preset.get_ET_full())))


@Validation.flask_blueprint.route('/presets/<preset_id>', methods=['DELETE'])
@db.write_session(db.VALIDATION_DB)
def delete_preset(session, preset_id):
    """Delete preset ``preset_id``; only its creator may do so.

    :raises fh.ServiceException: ``ERR_MISSING_PRESET`` if the preset does
        not exist, ``ERR_FORBIDDEN`` if the requester is not its creator.
    """
    uid = int(request.values['uid'])
    check_acl(uid)

    preset = session.query(Preset).get(preset_id)
    require(preset,
            fh.ServiceException(
                'Preset id %s is not found' % preset_id,
                status='ERR_MISSING_PRESET'))

    is_owner = preset.created_by == uid
    require(is_owner,
            fh.ServiceException(
                'User %s is not allowed to delete preset %s' % (uid, preset_id),
                status='ERR_FORBIDDEN'))

    session.delete(preset)
    session.commit()

    response = EM.tasks(EM.response_delete_validation_preset())
    return fh.xml_response(response)


@Validation.flask_blueprint.route('/messages/<message_id>/create-issue', methods=['POST'])
@db.write_session(db.CORE_DB)
def create_issue(session, message_id):
    """Get or create the issue for validation message ``message_id``.

    Request params: ``uid``, ``page-url`` and ``task-id``. The task supplies
    the branch. The response carries the issue key returned by
    ``IssueCreator.get_or_create_issue``.
    """
    uid = int(request.values['uid'])
    check_acl(uid)

    login = login_by_uid(uid)
    page_url = request.values['page-url']
    task_id = int(request.values['task-id'])

    task = session.query(ValidationTask).get(task_id)
    require(task,
            fh.ServiceException(
                'Task id %s is not found' % task_id,
                status='ERR_BAD_REQUEST'))

    key = get_issue_creator().get_or_create_issue(
        get_pgpool_validation_tasks(),
        message_id.encode('utf-8'),
        task.branch_id,
        login,
        page_url.encode('utf-8'))

    response = EM.tasks(EM.response_validation_create_issue(EM.key(key)))
    return fh.xml_response(response)


@Validation.flask_blueprint.route('/messages/<message_id>/view', methods=['POST', 'GET'])
@db.write_session(db.CORE_DB)
def message_set_viewed(session, message_id):
    """Mark message ``message_id`` of task ``task-id`` as viewed by ``uid``.

    Returns the updated message rendered by ``message_ET``.
    """
    uid = int(request.values['uid'])
    check_acl(uid)

    task_id = int(request.values['task-id'])
    task = session.query(ValidationTask).get(task_id)
    require(task,
            fh.ServiceException(
                'Task id %s is not found' % task_id,
                status='ERR_BAD_REQUEST'))

    gateway = create_results_gateway(task_id)
    viewed_datum = gateway.message_set_viewed(
        get_pgpool(db.CORE_DB),
        task.branch_id,
        uid,
        message_id.encode('utf-8'))

    return fh.xml_response(
        EM.response_validation_mark_viewed(message_ET(viewed_datum)))
