import datetime
import logging
import waffle

from sqlalchemy import orm
from sqlalchemy.sql.functions import now

from infra.cauth.server.common.alchemy import Session
from infra.cauth.server.common.constants import CLIENT_SOURCE_REASON, SERVER_TYPE, IDM_STATUS
from infra.cauth.server.common.models import (
    Source,
    Server,
    ServerTrustedSourceRelation,
    ServerGroup,
    ServerGroupTrustedSourceRelation,
)

from infra.cauth.server.master.api.idm.update import ensure_remote_nodes_correct
from infra.cauth.server.master.api.models import IdmUpdate
from infra.cauth.server.master.utils.fqdn import should_be_pushed
from infra.cauth.server.master.utils.log import security_log_info
from infra.cauth.server.master.utils.tasks import task
from infra.cauth.server.master.constants import BATCH_OPERATION

# Module-level logger shared by every task in this module.
log = logging.getLogger(__name__)

# Format of the security-log record that update_sources() writes for each
# affected server when its set of trusted sources changes.
SECURITY_LOG_PATTERN = 'server={server},new_sources="{new_sources}",old_sources="{old_sources}"'


def is_working_day(date):
    """Tell whether ``date`` falls on a weekday (Monday through Friday).

    Saturdays and Sundays are the only non-working days; holidays are not
    taken into account.
    """
    # date.weekday(): Monday == 0 ... Friday == 4, Saturday == 5, Sunday == 6.
    return date.weekday() < 5


def get_previous_working_day():
    """Return the most recent working day strictly before today.

    Walks backwards one calendar day at a time from yesterday until
    ``is_working_day`` accepts the candidate.
    """
    one_day = datetime.timedelta(days=1)
    candidate = datetime.date.today() - one_day
    while not is_working_day(candidate):
        candidate -= one_day
    return candidate


@task(use_lock=False, bind=True, dedicated_logger=False)
def push_idm_object(instance, dst, message=''):
    """Push the IDM state of a single destination and record the outcome.

    ``dst`` is resolved first as a ``Server`` fqdn, then as a ``ServerGroup``
    name; if neither exists the raw string itself is handed to
    ``ensure_remote_nodes_correct``. Destinations already in
    ``IDM_STATUS.ACTUAL`` are skipped.

    :param instance: the bound task (``bind=True``); used to schedule retries.
    :param dst: server fqdn or server group name.
    :param message: operation label forwarded to ``ensure_remote_nodes_correct``.
    """
    if waffle.switch_is_active('cauth.disable_idm_pushes'):
        log.info('push_idm_object(%s): switched off', dst)
        return

    log.info('push_idm_object(%s): started', dst)
    server_query = Server.query.filter_by(fqdn=dst)
    group_query = ServerGroup.query.filter_by(name=dst)
    dst_obj = server_query.first() or group_query.first() or dst
    dst_exists = isinstance(dst_obj, (Server, ServerGroup))

    if dst_exists and dst_obj.idm_status == IDM_STATUS.ACTUAL:
        log.info('push_idm_object(%s) is not required for dsts with actual state', dst_obj)
        return

    try:
        was_pushed = ensure_remote_nodes_correct(dst_obj, message)
    except Exception:
        log.exception('push_idm_object(%s): ensure failed', dst)
        # ``raise`` makes it explicit that execution never continues past a
        # failed push; otherwise ``was_pushed`` would be unbound below.
        raise instance.retry(countdown=180, max_retries=5)

    commit = False

    if dst_exists and dst_obj.first_pending_push_started_at:
        dst_obj.last_push_ended_at = datetime.datetime.now()
        commit = True

    if not was_pushed and dst_exists and dst_obj.idm_status != IDM_STATUS.ACTUAL:
        # was_pushed=False means dst was not pushed because there was nothing
        # to push (empty diff), so its state is already up to date.
        dst_obj.idm_status = IDM_STATUS.ACTUAL
        dst_obj.became_actual_at = datetime.datetime.now()
        log.info('push_idm_object(%s): marked as actual', dst)
        commit = True

    # Only open a write transaction when something was actually changed.
    if commit:
        Session.commit()

    log.info('push_idm_object(%s): done', dst)


@task(use_lock=False, bind=True, dedicated_logger=False)
def push_idm_update(instance, dst):
    """Process the newest finished ``IdmUpdate`` queued for ``dst``.

    All finished updates for ``dst`` are selected ``FOR UPDATE`` newest first.
    The newest one is stamped with ``started_at`` and every older one is
    deleted as superseded. The newest update's destination object is then
    pushed through ``ensure_remote_nodes_correct``. The update row is deleted
    after a successful push, or right away when the destination is already
    in ``IDM_STATUS.ACTUAL``. When the push fails the task is retried and the
    row is left in place.

    :param instance: the bound task (``bind=True``); used to schedule retries.
    :param dst: destination identifier stored in ``IdmUpdate.dst``.
    """
    if waffle.switch_is_active('cauth.disable_idm_pushes'):
        log.info('push_idm_update(%s): switched off', dst)
        return

    log.info('push_idm_update(%s): started', dst)
    query = IdmUpdate.query.filter_by(dst=dst, suite_is_finished=True)
    query = query.with_for_update().order_by(IdmUpdate.created_at.desc())
    updates = list(query)
    if updates:
        # Ordered newest first: the head gets processed, the rest are
        # superseded by it and can be dropped.
        updates[0].started_at = now()
        for superseded in updates[1:]:
            Session.delete(superseded)

    # Persist the bookkeeping above and release the row locks.
    Session.commit()

    if not updates:
        log.info('push_idm_update(%s): update is irrelevant', dst)
        return
    update = updates[0]

    dst_obj = update.get_object()
    dst_exists = isinstance(dst_obj, (Server, ServerGroup))
    if dst_exists and dst_obj.idm_status == IDM_STATUS.ACTUAL:
        log.info('push_idm_update(%s) deleting for dst with actual state', dst_obj)
        Session.delete(update)
        Session.commit()
        return

    try:
        was_pushed = ensure_remote_nodes_correct(dst_obj, BATCH_OPERATION.UPDATE_SUITE)
    except Exception:
        log.exception('push_idm_update(%s): ensure failed', dst)
        # ``raise`` makes it explicit that execution never continues past a
        # failed push; otherwise ``was_pushed`` would be unbound below.
        raise instance.retry(countdown=180, max_retries=5)

    if dst_exists and dst_obj.first_pending_push_started_at:
        dst_obj.last_push_ended_at = datetime.datetime.now()

    if not was_pushed and dst_exists and dst_obj.idm_status != IDM_STATUS.ACTUAL:
        # was_pushed=False means dst was not pushed because there was nothing
        # to push (empty diff), so its state is already up to date.
        dst_obj.idm_status = IDM_STATUS.ACTUAL
        dst_obj.became_actual_at = datetime.datetime.now()
        log.info('push_idm_update(%s): marked as actual', dst)

    log.info('push_idm_update(%s) deleting as successfully processed', dst)
    Session.delete(update)
    Session.commit()


@task(dedicated_logger=False)
def cleanup_idm_updates():
    """Bulk-delete every ``IdmUpdate`` created more than one day ago.

    The delete is issued with ``synchronize_session=False``, so objects
    already loaded into the session are not reconciled with it.
    """
    cutoff = datetime.datetime.now() - datetime.timedelta(days=1)
    (
        IdmUpdate.query
        .filter(IdmUpdate.created_at < cutoff)
        .delete(synchronize_session=False)
    )
    Session.commit()


def get_servers_from_obj(obj):
    """Expand ``obj`` into the servers it covers.

    A ``Server`` yields a one-element list containing itself; a
    ``ServerGroup`` yields its ``servers`` collection.

    :raises RuntimeError: if ``obj`` is neither a ``Server`` nor a ``ServerGroup``.
    """
    if isinstance(obj, Server):
        covered = [obj]
    elif isinstance(obj, ServerGroup):
        covered = obj.servers
    else:
        raise RuntimeError('Invalid object: {}'.format(type(obj).__name__))
    return covered


def update_sources(session, obj, sources_names, do_commit=True):
    """Make the trusted sources of a server or server group equal ``sources_names``.

    The pseudo-source ``'default'`` is ignored. Missing relations are created
    and obsolete ones are deleted. If anything changed, every affected server
    currently in ``IDM_STATUS.ACTUAL`` is marked ``DIRTY``. The change is then
    written to the security log, and an IDM push is queued for each server
    accepted by ``should_be_pushed``. If nothing changed, the function returns
    without touching the session.

    :param session: SQLAlchemy session used for every query and write.
    :param obj: a ``Server`` or ``ServerGroup`` instance.
    :param sources_names: iterable of source names the object should trust.
    :param do_commit: whether to commit ``session`` before queuing pushes.
    :raises RuntimeError: if a requested source does not exist, or if ``obj``
        is neither a ``Server`` nor a ``ServerGroup``.
    """
    sources_names = set(sources_names) - {'default'}

    # Query through the ``session`` argument so the caller's session is honoured.
    new_db_sources = session.query(Source).filter(Source.name.in_(sources_names)).all()
    new_sources_ids = {source.name: source.id for source in new_db_sources}
    new_sources = set(new_sources_ids)
    if len(sources_names) > len(new_sources):
        invalid_sources = sources_names - new_sources
        # Sorted so the error message does not depend on set iteration order.
        raise RuntimeError('Invalid sources: {}'.format(', '.join(sorted(invalid_sources))))

    if isinstance(obj, Server):
        relation_model = ServerTrustedSourceRelation
        obj_name = 'server'
        reason = CLIENT_SOURCE_REASON.FROM_CLIENT
    elif isinstance(obj, ServerGroup):
        relation_model = ServerGroupTrustedSourceRelation
        obj_name = 'servergroup'
        reason = CLIENT_SOURCE_REASON.FROM_SOURCE
    else:
        raise RuntimeError('Invalid object: {}'.format(type(obj).__name__))

    obj_id_attr = obj_name + '_id'

    # Materialize once: iterating a lazy Query twice would run the SQL twice.
    relations = (
        session.query(relation_model)
        .filter(getattr(relation_model, obj_name) == obj)
        .options(
            orm.joinedload(getattr(relation_model, obj_name), innerjoin=True),
            orm.joinedload(relation_model.source, innerjoin=True),
        )
        .all()
    )
    old_sources_relations = {relation.source.name: relation for relation in relations}
    old_sources = set(old_sources_relations)

    if new_sources == old_sources:
        return

    for source_name in new_sources - old_sources:
        relation_model.create(
            session=session,
            source_id=new_sources_ids[source_name],
            reason=reason,
            **{obj_id_attr: obj.id}
        )

    for source_name in old_sources - new_sources:
        old_sources_relations[source_name].delete(session)

    obj_servers = get_servers_from_obj(obj)
    for server in obj_servers:
        if server.idm_status == IDM_STATUS.ACTUAL:
            server.idm_status = IDM_STATUS.DIRTY
            server.became_dirty_at = datetime.datetime.now()

    # NOTE(review): with do_commit=False the pushes below are queued before
    # the caller commits, so a push task may observe uncommitted state —
    # confirm that callers passing False account for this.
    if do_commit:
        session.commit()

    for server in obj_servers:
        security_log_info(log, SECURITY_LOG_PATTERN.format(
            server=server.fqdn,
            new_sources=' '.join(sorted(new_sources)),
            old_sources=' '.join(sorted(old_sources)),
        ))
        if should_be_pushed(server.fqdn):
            push_idm_object.delay(server.fqdn, BATCH_OPERATION.UPDATE_SOURCES)


@task(dedicated_logger=False, use_lock=False)
def update_server_sources(server_fqdn, server_type, sources_names, client_version):
    """Apply client-reported type, trusted sources and client version to a server.

    Each of ``server_type``, ``sources_names`` and ``client_version`` is
    optional; passing ``None`` leaves the corresponding attribute untouched.
    Sources are delegated to ``update_sources``. The remaining changes are
    committed at the end.

    :raises RuntimeError: if no server has fqdn ``server_fqdn``, or if
        ``server_type`` is not listed in ``SERVER_TYPE.choices()``.
    """
    server = Session.query(Server).filter_by(fqdn=server_fqdn).first()
    if server is None:
        raise RuntimeError('Invalid server fqdn: {}'.format(server_fqdn))

    type_changed = server_type is not None and server_type != server.type
    if type_changed:
        allowed_types = SERVER_TYPE.choices()
        if server_type not in allowed_types:
            raise RuntimeError(
                'Invalid server type {}. Available: {}'.format(server_type, ','.join(allowed_types))
            )
        server.type = server_type

    if sources_names is not None:
        update_sources(Session, server, sources_names)

    # Evaluated after update_sources(), exactly like the attribute it compares.
    version_changed = client_version is not None and server.client_version != client_version
    if version_changed:
        server.client_version = client_version

    Session.commit()
