import logging

from django.http import HttpResponseBadRequest, HttpResponseForbidden

from infra.cauth.server.common.alchemy import Session
from infra.cauth.server.common.constants import IDM_STATUS, FLOW_TYPE
from infra.cauth.server.common.models import Server, ServerGroup, Source, ServerResponsible

from infra.cauth.server.master.api.views.base import BaseView, FormError
from infra.cauth.server.master.api.forms import RemoveServerForm, AddServerForm
from infra.cauth.server.master.api.tasks import push_idm_object, update_sources
from infra.cauth.server.master.utils.database import get_or_create_many, CreateRequest
from infra.cauth.server.master.utils.dns_status import create_or_update_dns_status
from infra.cauth.server.master.utils.fqdn import should_be_pushed
from infra.cauth.server.master.constants import BATCH_OPERATION

import datetime

# Module-level logger for the server add/remove views.
logger = logging.getLogger(__name__)


class BaseServerView(BaseView):
    """Common base for server views that schedule IDM pushes.

    Subclasses set ``self.server`` (and optionally ``self.all_groups``) and
    then call ``make_idm_updates`` to mark those objects as pending and queue
    ``push_idm_object`` tasks for them.
    """

    REQUIRE_CERT = 'servers'

    def __init__(self):
        super(BaseServerView, self).__init__()

        # Server handled by the request; assigned by subclasses.
        self.server = None
        # Groups whose IDM state must be refreshed together with the server.
        self.all_groups = []

    @staticmethod
    def _mark_pending_push(obj, now):
        """Flag a server or group as waiting for an IDM push.

        An object in ``IDM_STATUS.ACTUAL`` becomes ``IDM_STATUS.DIRTY`` and
        records ``now`` as ``became_dirty_at``. A new pending-push window is
        started at ``now`` when none is recorded yet, or when the last push
        ended after the currently recorded window started.
        """
        if obj.idm_status == IDM_STATUS.ACTUAL:
            obj.idm_status = IDM_STATUS.DIRTY
            obj.became_dirty_at = now
        if (
            obj.first_pending_push_started_at is None
            or (
                obj.last_push_ended_at
                and obj.last_push_ended_at > obj.first_pending_push_started_at
            )
        ):
            obj.first_pending_push_started_at = now

    def make_idm_updates(self, message=''):
        """Mark the server and its groups dirty, commit, then queue IDM pushes.

        Only objects accepted by ``should_be_pushed`` are touched: the server
        itself when its fqdn qualifies, and each group in ``self.all_groups``
        that contains at least one qualifying server.

        :param message: batch-operation tag forwarded to ``push_idm_object``;
            for groups ``BATCH_OPERATION.RUN_FOR_GROUP`` is appended to it.
        """
        # One timestamp for the whole update so related fields stay consistent.
        now = datetime.datetime.now()
        push_server = should_be_pushed(self.server.fqdn)

        if push_server:
            self._mark_pending_push(self.server, now)

        groups_to_push = [
            group for group in self.all_groups
            if any(should_be_pushed(server.fqdn) for server in group.servers)
        ]
        for group in groups_to_push:
            self._mark_pending_push(group, now)

        # Persist the dirty state before the async tasks can observe it.
        Session.commit()

        if push_server:
            push_idm_object.delay(self.server.fqdn, message=message)
        for group in groups_to_push:
            push_idm_object.delay(group.name, message=message + BATCH_OPERATION.RUN_FOR_GROUP)


class AddServerView(BaseServerView):
    """Add (or update) a server on behalf of the requesting source.

    Creates the server and any missing groups, syncs group membership and the
    source's responsibles, and applies source-specific settings. It then
    commits and schedules IDM pushes for the server and its groups.
    """

    _TEXT_CONTENT_TYPE = 'text/plain; charset=utf-8'

    def __init__(self):
        super(AddServerView, self).__init__()

        self.server_fqdn = None
        self.group_names = None
        self.responsibles = None
        self.server_type = None
        self.flow = None
        self.trusted_sources = None
        self.key_sources = None
        self.secure_ca_list_url = None
        self.insecure_ca_list_url = None
        self.krl_url = None
        self.sudo_ca_list_url = None

        # Groups created by this request (as opposed to already existing ones).
        self.created_groups = []

    @classmethod
    def _text_response(cls, response_cls, content):
        """Build a plain-text UTF-8 ``response_cls`` carrying ``content``."""
        return response_cls(content=content, content_type=cls._TEXT_CONTENT_TYPE)

    def clean_data(self):
        """Validate request params with ``AddServerForm`` and store them on self.

        Requires ``self.source`` to be set (its name is always trusted).

        :raises FormError: if the form does not validate.
        """
        form = AddServerForm(self.params)
        if not form.is_valid():
            raise FormError(form)

        data = form.cleaned_data
        self.group_names = {grp.lower() for grp in data['grp']}
        if data['resp'] is not None:
            self.responsibles = set(data['resp'])
        self.server_fqdn = data['srv'].lower()
        self.server_type = data.get('type')
        self.flow = data['flow']
        # Copy into a fresh set: don't mutate the form's cleaned data, and
        # tolerate a missing/None value. The requesting source is always trusted.
        self.trusted_sources = set(data.get('trusted_sources') or ())
        self.trusted_sources.add(self.source.name)

        self.key_sources = data.get('key_sources')
        self.secure_ca_list_url = data.get('secure_ca_list_url')
        self.insecure_ca_list_url = data.get('insecure_ca_list_url')
        self.krl_url = data.get('krl_url')
        self.sudo_ca_list_url = data.get('sudo_ca_list_url')

    def get_or_create_objects(self):
        """Fetch or create the server and the requested groups (locked for update).

        Populates ``self.server``, ``self.all_groups`` and ``self.created_groups``.

        :raises TypeError: if the creation helper returns an unexpected object.
        :raises RuntimeError: if no server is among the results.
        """
        create_requests = [
            CreateRequest(Server, 'fqdn', [self.server_fqdn]),
        ]
        if self.group_names:
            create_requests.append(
                CreateRequest(ServerGroup, 'name', self.group_names)
            )

        for result in get_or_create_many(Session, create_requests, for_update=True):
            obj = result.obj
            if isinstance(obj, Server):
                self.server = obj
                if result.is_created:
                    self.server.sources = self.sources
                    if self.server_type:
                        self.server.type = self.server_type
                elif self.source and not self.source.is_default:
                    # Link the requesting source without duplicating an existing link.
                    for source in self.sources:
                        if source not in self.server.sources:
                            self.server.sources.append(source)
            elif isinstance(obj, ServerGroup):
                self.all_groups.append(obj)
                if result.is_created:
                    self.created_groups.append(obj)
                    obj.source = self.source
            else:
                raise TypeError('Unexpected obj type in result: %s' % type(obj).__name__)

        if self.server is None:
            raise RuntimeError('Server not found in creation results')

    def update_responsibles(self):
        """Make this source's responsibles for the server equal ``self.responsibles``.

        Responsibles attached by other sources are left untouched. Groups
        created by this request get the full responsible list.
        """
        existing_server_resps = {r.user for r in self.server.responsibles if r.source == self.source}

        responsibles_to_remove = existing_server_resps - self.responsibles

        # Iterate over a copy because entries are removed from the collection.
        for sr in self.server.responsibles[:]:
            if sr.user in responsibles_to_remove and sr.server == self.server and sr.source == self.source:
                self.server.responsibles.remove(sr)

        for user in self.responsibles - existing_server_resps:
            self.server.responsibles.append(ServerResponsible(
                user=user,
                server=self.server,
                source=self.source,
            ))

        for group in self.created_groups:
            group.responsible_users = list(self.responsibles)

    def update_membership(self):
        """Add the server to every requested group.

        For modern sources, first detach the server from groups owned by this
        source that are not part of the request.
        """
        if self.source.is_modern:
            groups_to_remove = {
                group for group in self.server.groups
                if group.source == self.source and group not in self.all_groups
            }
            for group in groups_to_remove:
                self.server.groups.remove(group)
                # NOTE(review): presumably keeps the group attached to the session
                # so the membership change is flushed — confirm.
                Session.add(group)

        for group in set(self.all_groups) - set(self.server.groups):
            self.server.groups.append(group)

    def _validate_for_source(self):
        """Check source-specific request rules; return a 400 response or None."""
        if self.source.is_modern:
            if len(self.group_names) != 1:
                return self._text_response(
                    HttpResponseBadRequest,
                    "'grp' parameter for source '{}' must contain exact one name.".format(
                        self.source.name
                    ),
                )
        elif self.flow:
            return self._text_response(
                HttpResponseBadRequest,
                "Source '{source}' is not allowed to set flow".format(
                    source=self.source.name,
                ),
            )
        # Non-default sources may only use groups prefixed with their own name.
        if self.group_names and not self.source.is_default and not all(
                name.split('.')[0] == self.source.name for name in self.group_names
        ):
            return self._text_response(
                HttpResponseBadRequest,
                "Group name prefix must equal '{source}'".format(
                    source=self.source.name,
                ),
            )
        return None

    def _apply_modern_source_settings(self):
        """Apply settings only modern sources may change.

        :returns: None on success; on a permission or flow error the session is
            rolled back and a 403/400 plain-text response is returned.
        """
        if not self.check_authoritative_permission(self.server, self.source):
            Session.rollback()
            return self._text_response(
                HttpResponseForbidden,
                "Server '{server}' already has authoritative source".format(
                    server=self.server,
                ),
            )
        if self.flow:
            try:
                self.server.set_flow(self.flow)
            except RuntimeError as e:
                Session.rollback()
                return self._text_response(HttpResponseBadRequest, str(e))
        # Truthiness instead of len() so a missing (None) value is skipped too.
        if self.key_sources:
            self.server.key_sources = ",".join(sorted(self.key_sources))
        # Only overwrite URLs that were explicitly provided.
        for attr in ('secure_ca_list_url', 'insecure_ca_list_url', 'krl_url', 'sudo_ca_list_url'):
            value = getattr(self, attr)
            if value is not None:
                setattr(self.server, attr, value)
        return None

    def post(self, request):
        """Handle the add-server request.

        :returns: ``{'status': 'added', 'srv': <fqdn>}`` on success, or a
            plain-text 400/403 response when the request is not allowed.
        :raises FormError: if the request parameters are invalid.
        :raises Exception: any error while updating the database is logged,
            the session is rolled back and the error is re-raised.
        """
        if request.source_name and request.should_set_source:
            self.source = Source.query.filter_by(name=request.source_name).first()
        else:
            self.source = Source.query.filter_by(is_default=True).first()
        self.sources = [self.source]

        self.clean_data()

        error_response = self._validate_for_source()
        if error_response is not None:
            return error_response

        try:
            self.get_or_create_objects()
            self.update_membership()
            self.server.owner_id = self.server.owner_id or self.source.id
            if self.source.is_modern:
                error_response = self._apply_modern_source_settings()
                if error_response is not None:
                    return error_response

            if self.responsibles is not None:
                self.update_responsibles()

            create_or_update_dns_status(Session, self.server)
        except Exception as e:
            logger.exception(
                'Add server %s failed. Request params %s: Exception: %s',
                self.server,
                request.params,
                e,
            )
            Session.rollback()
            raise

        Session.commit()

        self.make_idm_updates(message=BATCH_OPERATION.ADD_SERVER)

        if self.flow == FLOW_TYPE.BACKEND_SOURCES:
            update_sources(Session, self.server.authoritative_group, self.trusted_sources)

        return {
            'status': 'added',
            'srv': self.server.fqdn,
        }


class RemoveServerView(BaseServerView):
    """Delete a server and schedule IDM pushes for it and its former groups."""

    def post(self, request):
        """Handle the remove-server request.

        :returns: ``{'status': 'removed', 'srv': <fqdn>}`` on success, or a
            plain-text 403 response when the requesting source is not allowed
            to remove the server.
        :raises FormError: if the request parameters are invalid.
        """
        form = RemoveServerForm(self.params)
        if not form.is_valid():
            raise FormError(form)

        self.server = Session.merge(form.cleaned_data['srv'])
        # None when the request carries no source name (or it is unknown).
        source = None
        if request.source_name is not None:
            source = Source.query.filter_by(name=request.source_name).first()
        if not self.check_authoritative_permission(self.server, source):
            return HttpResponseForbidden(
                content="Host is owned by other source",
                content_type='text/plain; charset=utf-8',
            )

        # Snapshot memberships before deleting, rather than aliasing the ORM
        # collection of an object that is about to be deleted.
        self.all_groups = list(self.server.groups)
        Session.delete(self.server)
        Session.commit()

        self.make_idm_updates(message=BATCH_OPERATION.REMOVE_SERVER)

        return {
            'status': 'removed',
            'srv': self.server.fqdn,
        }
