# coding: utf8

from collections import defaultdict
import logging
from lxml import etree as ET
import six
import sqlalchemy as sa
import shapely
from flask import Blueprint, request, abort

from yandex.maps.wiki import utils, db, fastcgihelpers as fh
from yandex.maps.wiki import config
from yandex.maps.wiki.pgpool3 import get_pgpool, init_pgpool

from yandex.maps.wiki.tasks import EM, register_task_type, grinder
from yandex.maps.wiki.tasks.models import Base, Task
from yandex.maps.wiki.tasks.stats import Statistics
from yandex.maps.wiki.utils import require, string_to_bool

from maps.wikimap.mapspro.libs.python import acl
from maps.wikimap.mapspro.libs.python.diffalert import ResultsViewer, StoredMessage, PostponeAction, SortKind, IssueCreator
from maps.wikimap.mapspro.libs.python import revision

# Task type name: used for task registration, the Flask blueprint name and
# the 'type' field of the job submitted to grinder.
TASK_NAME = 'diffalert'

# ACL permission path checked by check_acl().
ACL_PATH = 'mpro/tasks/%s' % TASK_NAME

# Accepted values of the 'sort_kind' result parameter.
BY_SIZE, BY_NAME = 'by_size', 'by_name'

# Accepted values of the 'user-filter' task creation parameter.
USER_FILTER_COMMON = 'common'
USER_FILTER_COMMON_OR_OUTSOURCERS = 'common-or-outsourcer'
USER_FILTER_ALL = 'all'

ALLOWED_USER_FILTERS = [USER_FILTER_COMMON, USER_FILTER_COMMON_OR_OUTSOURCERS, USER_FILTER_ALL]

# Maximum number of branches listed by DiffAlert.capabilities_ET().
BRANCHES_LIMIT = 20


@utils.threadsafe_memoize
def get_pgpool_validation_tasks():
    """Return the validation-DB connection pool, initializing it on first use.

    The pool is initialized under the 'tasks' name from the validation
    database config; the result is memoized by ``utils.threadsafe_memoize``.
    """
    validation_db = db.VALIDATION_DB
    db_config = config.get_config().databases[validation_db]
    init_pgpool(db_config, 'tasks')
    return get_pgpool(validation_db)


@utils.threadsafe_memoize
def get_issue_creator():
    """Return a memoized IssueCreator bound to the configured Startrek API URL."""
    api_base_url = config.get_config().xml.get_attr("common/st", "api-base-url")
    logging.info("Startrek URL %s", api_base_url)
    creator = IssueCreator(api_base_url)
    return creator


def create_results_viewer(task_id):
    """Return a ResultsViewer for *task_id* backed by the validation DB pool."""
    pool = get_pgpool_validation_tasks()
    return ResultsViewer(task_id, pool)


def create_grinder_gateway():
    """Return a GrinderGateway pointed at the configured grinder host."""
    grinder_host = config.get_config().grinder_params.host
    return grinder.GrinderGateway(grinder_host)


def get_snapshot_id(branch_id):
    """Capture the current snapshot of a branch.

    Returns:
        ``(branch_id, head_commit_id)`` for the branch.

    Raises:
        fh.ServiceException(INVALID_OPERATION): the branch is an approved branch.
    """
    gateway = revision.create_gateway(db.CORE_DB, branch_id)
    branch_kind = gateway.branch_type()
    is_supported = branch_kind != revision.BRANCH_TYPE_APPROVED
    require(
        is_supported,
        fh.ServiceException(
            'Unsupported branch type: %s' % branch_kind,
            status='INVALID_OPERATION'))
    head_commit = gateway.head_commit_id()
    return branch_id, head_commit


def load_category_groups():
    """Load the category groups declared in the editor config.

    Parses the editor config XML (with XInclude resolved) and, for every
    ``category-groups/category-group`` element under the root, collects the
    ``id`` attributes of its ``categories/category`` children.

    Returns:
        dict mapping category-group id -> list of category ids. Groups that
        appear more than once with the same id are merged, in document order.
    """
    category_groups = {}
    editor_config = ET.parse(config.get_config().editor_params.config)
    editor_config.xinclude()
    # Search explicitly relative to the root element. An absolute path
    # ('/category-groups/...') on an ElementTree is silently rewritten to
    # './category-groups/...' but emits a FutureWarning on every call.
    root = editor_config.getroot()
    for cg_node in root.findall('category-groups/category-group'):
        cat_ids = category_groups.setdefault(cg_node.attrib['id'], [])
        cat_ids.extend(
            cat_node.attrib['id']
            for cat_node in cg_node.findall('categories/category'))
    return category_groups


def aoi_geom_wkb(aoi_id, branch_id):
    """Return the stored geometry of an AOI object at the branch head.

    Callers treat the result as WKB (it is passed to ``shapely.wkb.loads``
    or used directly as ``geom_wkb``).

    Raises:
        fh.ServiceException(ERR_MISSING_OBJECT): no such object at the branch
            head, or it lacks the ``cat:aoi`` attribute.
        fh.ServiceException(INTERNAL_ERROR): the AOI object has no geometry.
    """
    rgateway = revision.create_gateway(db.CORE_DB, branch_id)
    aoi = rgateway.object_revision(aoi_id, rgateway.head_commit_id())
    if aoi is None or 'cat:aoi' not in aoi.attrs:
        raise fh.ServiceException('Aoi id: %s not found' % aoi_id,
                                  status='ERR_MISSING_OBJECT')
    if aoi.geom is None:
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O`` and a None geometry would then reach shapely.
        raise fh.ServiceException('Aoi id: %s has no geometry' % aoi_id,
                                  status='INTERNAL_ERROR')
    return aoi.geom


def bbox_geom(bbox_str):
    """Build a Mercator box polygon from a geodetic bounding-box string.

    Args:
        bbox_str: four comma-separated numbers ``min_x,min_y,max_x,max_y`` in
            geodetic coordinates; each corner is converted with
            ``utils.geodetic_to_mercator``.

    Returns:
        ``shapely.geometry.box`` spanning the converted corners.

    Raises:
        fh.ServiceException(ERR_BAD_REQUEST): not exactly four numbers.
    """
    def throw_wrong_format():
        raise fh.ServiceException('wrong bbox format: '
                                  + bbox_str, status='ERR_BAD_REQUEST')
    try:
        # Materialize a list: on Python 3 ``map`` is lazy, so ``len`` would
        # raise TypeError and parse errors would escape this ``except``.
        coords = [float(c) for c in bbox_str.split(',')]
    except ValueError:
        throw_wrong_format()
    if len(coords) != 4:
        throw_wrong_format()

    min_x, min_y = utils.geodetic_to_mercator(coords[0], coords[1])
    max_x, max_y = utils.geodetic_to_mercator(coords[2], coords[3])
    return shapely.geometry.box(min_x, min_y, max_x, max_y)


def check_acl(uid):
    """Ensure *uid* holds the ``ACL_PATH`` permission.

    Raises:
        fh.ServiceException(ERR_FORBIDDEN): permission not granted.
    """
    granted = acl.is_permission_granted(get_pgpool(db.CORE_DB), ACL_PATH, uid)
    require(granted, fh.ServiceException('forbidden', status='ERR_FORBIDDEN'))


def login_by_uid(uid):
    """Return the login of user *uid*, looked up via the core DB ACL tables."""
    core_pool = get_pgpool(db.CORE_DB)
    return acl.login_by_uid(core_pool, uid)


def extract_filters(category_groups, kwargs):
    """Translate request/preset parameters into ResultsViewer filter kwargs.

    Args:
        category_groups: mapping category-group id -> list of category ids
            (as returned by ``load_category_groups``).
        kwargs: request parameters with underscore-style keys. Recognized:
            ``preset``, ``aoi``, ``bb``, ``branch``, ``major_priority``,
            ``category_group``, ``description``, ``region_priority``,
            ``postponed``, ``exclude_inspected_by``. Not modified.

    Returns:
        dict with any of ``geom_wkb``, ``major_priority``, ``categories``,
        ``description``, ``region_priority``, ``postponed``,
        ``exclude_inspected_by``.

    When ``preset`` is given, the preset's stored attributes replace the
    individual ``aoi``/priority/category/description params (which must then
    be absent). When both ``aoi`` and ``bb`` are given, the geometry filter
    is their intersection, keeping only polygon parts of a multi-part result.
    """
    if 'preset' in kwargs:
        require(
            'aoi' not in kwargs and 'region_priority' not in kwargs and 'major_priority' not in kwargs and
            'category_group' not in kwargs and 'description' not in kwargs,
            fh.ServiceException(
                "When 'preset' is set, no one of params 'aoi', 'region-priority', 'major-priority', " +
                "'category-group' and 'description' can be set",
                status='ERR_BAD_REQUEST'))

        with db.get_write_session(db.VALIDATION_DB) as session:
            preset_id = kwargs['preset']
            preset = session.query(Preset).get(preset_id)
            require(preset,
                    fh.ServiceException(
                        'Preset id %s is not found' % preset_id,
                        status='ERR_MISSING_PRESET'))
            # Merge into a copy so the caller's mapping is left untouched.
            merged = dict(kwargs)
            merged.update(preset.get_attributes())
            kwargs = merged

    ret = {}

    branch_id = int(kwargs.get('branch', 0))
    if 'bb' in kwargs and 'aoi' in kwargs:
        aoi_geom = shapely.wkb.loads(
            aoi_geom_wkb(int(kwargs['aoi']), branch_id))
        geom = aoi_geom.intersection(bbox_geom(str(kwargs['bb'])))
        if isinstance(geom, shapely.geometry.base.BaseMultipartGeometry):
            # Drop lower-dimensional pieces (points/lines) that an
            # intersection can produce along shared boundaries.
            geom = shapely.geometry.MultiPolygon(
                [g for g in geom.geoms
                 if isinstance(g, shapely.geometry.Polygon)])

        ret['geom_wkb'] = shapely.wkb.dumps(geom)
    elif 'bb' in kwargs:
        ret['geom_wkb'] = shapely.wkb.dumps(bbox_geom(str(kwargs['bb'])))
    elif 'aoi' in kwargs:
        ret['geom_wkb'] = aoi_geom_wkb(int(kwargs['aoi']), branch_id)

    if 'major_priority' in kwargs:
        req_priority = kwargs['major_priority']
        try:
            req_priority = int(req_priority)
        except ValueError:
            abort(400, 'bad priority: %s' % req_priority)
        ret['major_priority'] = req_priority

    if 'category_group' in kwargs:
        req_group = kwargs['category_group']
        if req_group not in category_groups:
            abort(400, 'unknown category group: %s' % req_group)
        ret['categories'] = category_groups[req_group]

    if 'description' in kwargs:
        ret['description'] = str(kwargs['description'])

    if 'region_priority' in kwargs:
        region_priority = str(kwargs['region_priority'])
        try:
            region_priority = int(region_priority)
        except ValueError:
            abort(400, 'bad region priority: %s' % region_priority)
        ret['region_priority'] = region_priority

    if 'postponed' in kwargs:
        ret['postponed'] = string_to_bool(kwargs['postponed'])

    if 'exclude_inspected_by' in kwargs:
        ret['exclude_inspected_by'] = int(kwargs['exclude_inspected_by'])

    return ret


def message_ET(message):
    """Serialize a stored diffalert message to a ``message`` XML element.

    The object id becomes an ``oid`` child. Priority, description and
    postponed state become attributes. ``object-label`` is added when the
    message has a label, which is decoded from UTF-8 bytes. ``inspected-by``
    and ``inspected-at`` are added when the message was inspected.
    """
    priority = message.priority
    element = EM.message(
        EM.oid(message.object_id),
        id=message.id,
        major_priority=priority.major,
        minor_priority=priority.minor,
        description=message.description,
        postponed=message.postponed)

    label = message.object_label
    if label:
        element.set('object-label', label.decode('utf-8'))

    inspector = message.inspected_by
    if inspector:
        element.set('inspected-by', str(inspector))
        element.set('inspected-at', message.inspected_at)

    return element


def branch_ET(branch, production_branch):
    """Serialize a revision branch to a ``branch`` XML element.

    The branch that equals *production_branch* is reported with type
    ``"production"``. Creation info is included for non-trunk branches, and
    publication info for archive/deleted branches.

    Raises:
        fh.ServiceException(INTERNAL_ERROR): *branch* is an approved branch.
    """
    kind = branch.type
    if kind == revision.BRANCH_TYPE_APPROVED:
        raise fh.ServiceException(
            'got approved branch', status='INTERNAL_ERROR')

    shown_type = kind if branch != production_branch else "production"
    element = EM.branch(
        id=branch.id,
        type=shown_type,
        state=branch.state)

    if kind != revision.BRANCH_TYPE_TRUNK:
        element.append(EM.created(branch.created_at))
        element.append(EM.created_by(branch.created_by))

    if kind in (revision.BRANCH_TYPE_ARCHIVE, revision.BRANCH_TYPE_DELETED):
        element.append(EM.published(branch.finished_at))
        element.append(EM.published_by(branch.finished_by))

    return element


class Preset(Base):
    """Saved diffalert filter preset (table ``diffalert.preset``).

    A preset stores a named combination of result filters (AOI, region and
    major priority, category group, description). ``extract_filters`` expands
    it in place of the individual request parameters.
    """
    __tablename__ = 'preset'
    __table_args__ = {'schema': 'diffalert'}

    id = sa.Column(sa.BigInteger, primary_key=True)
    name = sa.Column(sa.String)
    created_at = sa.Column(sa.DateTime)
    created_by = sa.Column(sa.BigInteger)  # uid of the author; only they may change/delete
    aoi_id = sa.Column(sa.BigInteger)
    pri_region = sa.Column(sa.Integer)
    pri_major = sa.Column(sa.Integer)
    category_group = sa.Column(sa.String)
    description = sa.Column(sa.String)

    @sa.orm.validates('name')
    def validate_name(self, key, name):
        """Strip surrounding whitespace and reject names that end up empty."""
        name = name.strip()
        require(name, fh.ServiceException('Empty name', status='ERR_BAD_REQUEST'))
        return name

    def get_attributes(self):
        """Return the preset as ``extract_filters`` parameters (underscore keys).

        ``aoi`` is always present. The other keys appear only when set.
        """
        ret = {}
        ret['aoi'] = self.aoi_id

        if self.pri_region is not None:
            ret['region_priority'] = self.pri_region
        if self.pri_major is not None:
            ret['major_priority'] = self.pri_major
        if self.category_group:
            ret['category_group'] = self.category_group
        if self.description:
            ret['description'] = self.description

        return ret

    @staticmethod
    def _int_param(kwargs, key):
        """Return ``int(kwargs[key])``, or None when *key* is absent.

        Raises:
            fh.ServiceException(ERR_BAD_REQUEST): the value is not an integer.
        """
        if key not in kwargs:
            return None
        value = kwargs[key]
        try:
            return int(value)
        except (TypeError, ValueError):
            # Report malformed input as a client error instead of letting a
            # bare ValueError surface as an internal error.
            raise fh.ServiceException(
                "Request param '%s' must be an integer, got: %s" % (key, value),
                status='ERR_BAD_REQUEST')

    def set_attributes(self, kwargs):
        """Overwrite the preset from request params (hyphen-style keys).

        ``name`` and ``aoi`` are required. ``region-priority``,
        ``major-priority``, ``category-group`` and ``description`` are
        optional and are reset to None when absent.

        Raises:
            fh.ServiceException(ERR_BAD_REQUEST): a required param is missing,
                the name is empty, or an integer param is malformed.
        """
        require('name' in kwargs,
                fh.ServiceException(
                    "Request param 'name' must be set",
                    status='ERR_BAD_REQUEST'))
        require('aoi' in kwargs,
                fh.ServiceException(
                    "Request param 'aoi' must be set",
                    status='ERR_BAD_REQUEST'))

        self.name = kwargs['name']
        self.aoi_id = self._int_param(kwargs, 'aoi')

        self.pri_region = self._int_param(kwargs, 'region-priority')
        self.pri_major = self._int_param(kwargs, 'major-priority')
        self.category_group = kwargs['category-group'] if 'category-group' in kwargs else None
        self.description = kwargs['description'] if 'description' in kwargs else None

    def get_ET_brief(self):
        """Return a ``diffalert-preset`` element with the name, id and author."""
        return EM.diffalert_preset(
            self.name,
            id=self.id,
            created_by=self.created_by)

    def get_ET_full(self):
        """Return a ``diffalert-preset`` element with all stored attributes."""
        ret = EM.diffalert_preset(
            EM.name(self.name),
            EM.aoi(self.aoi_id),
            id=self.id,
            created_by=self.created_by,
            created_at=self.created_at)
        if self.pri_region is not None:
            ret.append(EM.region_priority(self.pri_region))
        if self.pri_major is not None:
            ret.append(EM.major_priority(self.pri_major))
        if self.category_group:
            ret.append(EM.category_group(self.category_group))
        if self.description:
            ret.append(EM.description(self.description))
        return ret


@register_task_type(name=TASK_NAME)
class DiffAlert:
    """Task-type hooks and HTTP handlers for diffalert tasks.

    The task framework uses ``capabilities_ET``, ``create`` and ``launch``.
    ``flask_blueprint`` exposes endpoints that inspect or postpone result
    messages, create Startrek issues, and manage filter presets.
    """

    @staticmethod
    def capabilities_ET():
        """Describe the task type: saved presets and selectable branches.

        At most ``BRANCHES_LIMIT`` branches are listed, in the order trunk,
        stable, archive, deleted. The first archive branch returned by the
        branch manager is reported as the production branch.
        """
        with db.get_write_session(db.VALIDATION_DB) as session:
            query = session.query(Preset).order_by(sa.desc('created_at'))
            presets = EM.diffalert_presets(*[preset.get_ET_brief() for preset in query.all()])

        branch_mgr = revision.BranchManager(get_pgpool(db.CORE_DB))

        branches = []
        branches += branch_mgr.branches_by_type(revision.BRANCH_TYPE_TRUNK)
        branches += branch_mgr.branches_by_type(revision.BRANCH_TYPE_STABLE)

        archive_branches = branch_mgr.branches_by_type(
            revision.BRANCH_TYPE_ARCHIVE,
            BRANCHES_LIMIT - len(branches))
        production_branch = archive_branches[0] if len(archive_branches) > 0 else None
        branches += archive_branches

        branches += branch_mgr.branches_by_type(
            revision.BRANCH_TYPE_DELETED,
            BRANCHES_LIMIT - len(branches))

        top_branches = branches[:BRANCHES_LIMIT]

        return EM.diffalert_task_type(
            presets,
            EM.branches(*[branch_ET(branch, production_branch) for branch in top_branches]))

    @staticmethod
    def create(uid, request):
        """Create (but do not launch) a diffalert task from request params.

        Required params are ``base-branch`` and ``branch`` (branch ids). Their
        head commits are captured now. Optional params are
        ``with-imported-objects`` (bool, default false) and ``user-filter``
        (one of ``ALLOWED_USER_FILTERS``, default ``all``).
        """
        check_acl(uid)

        old_branch_id = int(request.values['base-branch'])
        new_branch_id = int(request.values['branch'])

        task = DiffAlertTask()
        task.on_create(uid)
        task.old_branch_id, task.old_commit_id = get_snapshot_id(old_branch_id)
        task.new_branch_id, task.new_commit_id = get_snapshot_id(new_branch_id)
        task.with_imported_objects = string_to_bool(request.values.get('with-imported-objects', 'false'))

        user_filter = request.values.get('user-filter', USER_FILTER_ALL)
        require(user_filter in ALLOWED_USER_FILTERS,
                fh.ServiceException('Unsupported user filter: %s' % user_filter,
                                    status='ERR_BAD_REQUEST'))
        task.user_filter = user_filter

        return task

    @staticmethod
    def launch(session, task_id, request):
        """Submit the stored task's parameters to grinder and return its result."""
        task = session.query(DiffAlertTask).get(task_id)

        args = {
            'type': TASK_NAME,
            'taskId': task.id,
            'oldBranchId': task.old_branch_id,
            'oldCommitId': task.old_commit_id,
            'newBranchId': task.new_branch_id,
            'newCommitId': task.new_commit_id,
            'withImportedObjects' : task.with_imported_objects,
            'userFilter' : task.user_filter
        }

        gateway = create_grinder_gateway()
        return gateway.submit(args)

    flask_blueprint = Blueprint(TASK_NAME, __name__)

    @flask_blueprint.route('/messages/mark-as-inspected', methods=['POST'])
    def handle_mark_as_inspected():
        """Mark messages as inspected by ``uid``.

        Messages are selected either explicitly through ``message-ids``
        (comma-separated) or by the results-view filters applied to task
        ``task-id``. The marked messages are echoed back only in the
        explicit-ids case.
        """
        uid = int(request.values['uid'])
        check_acl(uid)

        by_ids = 'message-ids' in request.values
        if by_ids:
            message_ids = [int(id_str) for id_str in request.values['message-ids'].split(',')]
        else:
            category_groups = load_category_groups()
            filters = extract_filters(
                category_groups,
                dict((k.replace('-', '_'), v) for k, v in six.iteritems(request.values)))
            viewer = create_results_viewer(int(request.values['task-id']))
            message_ids = viewer.message_ids(**filters)

        messages = StoredMessage.mark_as_inspected(
            get_pgpool_validation_tasks(), message_ids, uid)
        return fh.xml_response(
            EM.tasks(
                EM.response_diffalert_mark_as_inspected(
                    *([message_ET(m) for m in messages] if by_ids else []))))

    @flask_blueprint.route('/messages/<message_id>/postpone', methods=['POST'])
    def postpone_message(message_id):
        """Apply a postpone action (``action`` param, default ``postpone``) to a message."""
        uid = int(request.values['uid'])
        check_acl(uid)

        action_str = request.values.get('action', 'postpone').upper()
        require(action_str in PostponeAction.names,
                fh.ServiceException('Unsupported action: %s' % action_str,
                                    status='ERR_BAD_REQUEST'))
        action = PostponeAction.names[action_str]

        message = StoredMessage.postpone(
            get_pgpool_validation_tasks(), int(message_id), action)
        return fh.xml_response(
            EM.tasks(
                EM.response_diffalert_postpone(message_ET(message))))

    @flask_blueprint.route('/messages/<message_id>/create-issue', methods=['POST'])
    @db.write_session('core')
    def create_issue(session, message_id):
        """Get or create a Startrek issue for a message; respond with its key.

        Raises:
            fh.ServiceException(ERR_BAD_REQUEST): unknown message or task.
        """
        uid = int(request.values['uid'])
        check_acl(uid)

        page_url = request.values['page-url']

        with db.get_write_session(db.VALIDATION_DB) as validation_session:
            sql = sa.text(
                "SELECT task_id FROM diffalert.messages WHERE id = :id")
            result = validation_session.execute(sql, {'id' : int(message_id)})
            task_id = result.scalar()

        # No row means the message does not exist; report that instead of
        # looking up a task with id None.
        require(task_id is not None,
                fh.ServiceException(
                    'Message id %s is not found' % message_id,
                    status='ERR_BAD_REQUEST'))

        task = session.query(DiffAlertTask).get(task_id)
        require(task,
                fh.ServiceException(
                    'Task id %s is not found' % task_id,
                    status='ERR_BAD_REQUEST'))

        branch_id = task.new_branch_id

        viewer = create_results_viewer(int(task_id))
        stored_message = viewer.message(int(message_id))
        require(stored_message,
                fh.ServiceException(
                    'Message id %s is not found' % message_id,
                    status='ERR_BAD_REQUEST'))

        # Categories not listed in any group fall back to 'other', the same
        # label the results statistics use. Without this default the name
        # below would be unbound. If a category is in several groups, the
        # last match wins.
        message_category_group_id = 'other'
        category_groups = load_category_groups()
        for group_id, category_ids in six.iteritems(category_groups):
            if stored_message.category_id in category_ids:
                message_category_group_id = group_id

        login = login_by_uid(uid)

        issue_creator = get_issue_creator()
        key = issue_creator.get_or_create_issue(
            get_pgpool_validation_tasks(),
            stored_message,
            branch_id,
            login,
            message_category_group_id,
            page_url.encode('utf-8'))

        return fh.xml_response(
            EM.tasks(
                EM.response_diffalert_create_issue(
                    EM.key(key))))

    @flask_blueprint.route('/presets', methods=['GET'])
    @db.write_session(db.VALIDATION_DB)
    def get_presets(session):
        """List presets, newest first, paginated by ``page``/``per-page``."""
        int(request.values.get('uid', 0))  # validate that uid, if given, is an integer
        page = int(request.values.get('page', 1))
        per_page = int(request.values.get('per-page', 10))

        # Same ordering expression as capabilities_ET; a raw 'created_at desc'
        # string is deprecated/rejected by SQLAlchemy's order_by.
        query = session.query(Preset).order_by(sa.desc('created_at'))
        total_count = query.count()

        page = fh.correct_page(page, per_page, total_count)
        offset = (page - 1) * per_page

        return fh.xml_response(
            EM.tasks(
                EM.response_diffalert_presets(
                    EM.diffalert_presets(
                        page=page,
                        per_page=per_page,
                        total_count=total_count,
                        *[preset.get_ET_full() for preset in query[offset : offset+per_page]]))))

    @flask_blueprint.route('/presets', methods=['POST'])
    @db.write_session(db.VALIDATION_DB)
    def create_preset(session):
        """Create a preset owned by ``uid`` from the request params."""
        preset = Preset()
        preset.set_attributes(request.values)
        preset.created_by = int(request.values['uid'])
        preset.created_at = utils.utcnow()

        session.add(preset)
        session.commit()

        return fh.xml_response(
            EM.tasks(
                EM.response_save_diffalert_preset(
                    preset.get_ET_full())))

    @flask_blueprint.route('/presets/<preset_id>', methods=['PUT'])
    @db.write_session(db.VALIDATION_DB)
    def change_preset(session, preset_id):
        """Overwrite a preset; only its author (``uid``) may do so."""
        uid = int(request.values['uid'])

        preset = session.query(Preset).get(preset_id)
        require(preset,
                fh.ServiceException(
                    'Preset id %s is not found' % preset_id,
                    status='ERR_MISSING_PRESET'))
        require(preset.created_by == uid,
                fh.ServiceException(
                    'User %s is not allowed to change preset %s' % (uid, preset_id),
                    status='ERR_FORBIDDEN'))

        preset.set_attributes(request.values)

        session.commit()

        return fh.xml_response(
            EM.tasks(
                EM.response_save_diffalert_preset(
                    preset.get_ET_full())))

    @flask_blueprint.route('/presets/<preset_id>', methods=['DELETE'])
    @db.write_session(db.VALIDATION_DB)
    def delete_preset(session, preset_id):
        """Delete a preset; only its author (``uid``) may do so."""
        uid = int(request.values['uid'])

        preset = session.query(Preset).get(preset_id)
        require(preset,
                fh.ServiceException(
                    'Preset id %s is not found' % preset_id,
                    status='ERR_MISSING_PRESET'))
        require(preset.created_by == uid,
                fh.ServiceException(
                    'User %s is not allowed to delete preset %s' % (uid, preset_id),
                    status='ERR_FORBIDDEN'))

        session.delete(preset)
        session.commit()

        return fh.xml_response(
            EM.tasks(
                EM.response_delete_diffalert_preset()))


class DiffAlertTask(Task):
    """Persisted diffalert task comparing two branch snapshots.

    Rows live in ``service.diffalert_task`` and extend ``service.task``
    (polymorphic identity ``'diffalert'``). Results are read with a
    ``ResultsViewer`` over the validation database.
    """
    __tablename__ = 'diffalert_task'
    __table_args__ = {'schema': 'service'}
    __mapper_args__ = {'polymorphic_identity': 'diffalert'}

    id = sa.Column(sa.BigInteger, sa.ForeignKey('service.task.id'), primary_key=True)
    old_branch_id = sa.Column(sa.BigInteger)  # 'base-branch' snapshot
    old_commit_id = sa.Column(sa.BigInteger)
    new_branch_id = sa.Column(sa.BigInteger)  # 'branch' snapshot
    new_commit_id = sa.Column(sa.BigInteger)
    result_url = sa.Column(sa.Text)
    with_imported_objects = sa.Column(sa.Boolean)
    user_filter = sa.Column(sa.Text)

    def context_ET_brief(self, *args, **kwargs):
        """Return the task's input parameters as a ``diffalert-context`` element."""
        return EM.diffalert_context(
            EM.base_branch(self.old_branch_id),
            EM.branch(self.new_branch_id),
            EM.with_imported_objects(self.with_imported_objects),
            EM.user_filter(self.user_filter))

    def result_ET_brief(self, *args, **kwargs):
        """Return the total message count and result URL, if known.

        Best effort: any failure while reading results is logged and None is
        returned.
        """
        try:
            viewer = create_results_viewer(self.id)
            ret = EM.diffalert_result(
                EM.messages(total_count=viewer.message_count()))

            if self.result_url is not None:
                ret.append(EM.url(self.result_url))

            return ret
        except Exception as e:
            logging.exception(e)
            return None

    def result_ET_full(self, page=1, per_page=10, *args, **kwargs):
        """Build the full result view: statistics plus one page of messages.

        Args:
            page: 1-based page number (int or numeric string).
            per_page: page size (int or numeric string).
            **kwargs: filters understood by ``extract_filters`` plus
                ``sort_kind`` (``by_name``, the default, or ``by_size``).

        Each statistics level applies more of the filters than the previous
        one. Region priorities use only geometry and postponed. Major
        priorities also use the region priority. Category groups also use
        the major priority.

        Raises:
            fh.ServiceException(ERR_BAD_REQUEST): unknown ``sort_kind``.
        """
        category_groups = load_category_groups()
        group_by_cat = dict(
            (cat_id, group_id)
            for group_id, cat_ids in six.iteritems(category_groups)
            for cat_id in cat_ids)
        viewer = create_results_viewer(self.id)

        sort_kind_str = kwargs.get('sort_kind', BY_NAME)
        require(sort_kind_str in [BY_SIZE, BY_NAME],
                fh.ServiceException('Wrong sort_kind %s' % sort_kind_str,
                                    status='ERR_BAD_REQUEST'))
        sort_kind = SortKind.BY_SIZE if sort_kind_str == BY_SIZE else SortKind.BY_NAME

        def new_counter():
            return {'total_count': 0, 'inspected_count': 0}

        def select_filters(filters, allowed):
            """Keep only the filters that viewer.statistics should honor."""
            return dict((k, v) for k, v in filters.items() if k in allowed)

        def accumulate(sum_item, item):
            sum_item['total_count'] += item.total_count
            sum_item['inspected_count'] += item.inspected_count

        def dict_to_stats(field, stats_dict):
            return [Statistics(field, id=key, count=item['total_count'], inspected=item['inspected_count'])
                    for key, item in six.iteritems(stats_dict)]

        def stats_by_region_priority(filters):
            stats_dict = defaultdict(new_counter)
            stats_filters = select_filters(filters, ('geom_wkb', 'postponed'))
            for item in viewer.statistics(**stats_filters):
                accumulate(stats_dict[item.region_priority], item)

            return dict_to_stats('region_priority', stats_dict)

        def stats_by_major_priority(filters):
            stats_dict = defaultdict(new_counter)
            stats_filters = select_filters(
                filters, ('geom_wkb', 'region_priority', 'postponed'))
            for item in viewer.statistics(**stats_filters):
                accumulate(stats_dict[item.major_priority], item)

            return dict_to_stats('major_priority', stats_dict)

        def stats_by_cat_group(filters):
            # group id -> description -> counters
            stats_dict = defaultdict(lambda: defaultdict(new_counter))
            stats_filters = select_filters(
                filters, ('geom_wkb', 'major_priority', 'region_priority', 'postponed'))
            for item in viewer.statistics(**stats_filters):
                group = group_by_cat.get(item.category_id, 'other')
                accumulate(stats_dict[group][item.description], item)

            return [
                Statistics('category_group', id=group_id, inspected=None,
                           children=dict_to_stats('description', desc_stats_dict))
                for group_id, desc_stats_dict in six.iteritems(stats_dict)]

        def statistics_ET(stats_by_region_priority, stats_by_major_priority, stats_by_cat_group):
            return EM.statistics(
                EM.region_priorities(*[item.get_ET() for item in stats_by_region_priority]),
                EM.major_priorities(*[item.get_ET() for item in stats_by_major_priority]),
                EM.category_groups(*[item.get_ET() for item in stats_by_cat_group]))

        def messages_ET(page, per_page, filters):
            total_count = viewer.message_count(**filters)

            # Both values may arrive as request strings; normalize them to int
            # before the paging arithmetic.
            page = int(page)
            per_page = int(per_page)
            page = fh.correct_page(page, per_page, total_count)
            offset = (page - 1) * per_page

            messages = viewer.messages(sort_kind, offset, per_page, **filters)

            return EM.messages(
                total_count=total_count,
                page=page,
                per_page=per_page,
                *[message_ET(msg) for msg in messages])

        filters = extract_filters(category_groups, kwargs)
        ret = EM.diffalert_result(
            statistics_ET(
                stats_by_region_priority(filters),
                stats_by_major_priority(filters),
                stats_by_cat_group(filters)),
            messages_ET(page, per_page, filters))

        if self.result_url is not None:
            ret.append(EM.url(self.result_url))

        return ret
