import logging
import datetime
from requests.exceptions import HTTPError
from sqlalchemy.orm import Session

from watcher.config import settings
from watcher.db.base import dbconnect
from watcher.tasks.base import lock_task
from watcher.crud.notification import (
    query_unsend_notifications,
    get_shifts_for_notify_start_soon,
)
from watcher.crud.problem import (
    query_active_problems_without_notifications,
)
from watcher.logic.timezone import now
from watcher.logic.clients.jns import jns_client
from watcher.logic.notification import (
    get_notification_channels,
    get_start_shift_params,
    get_start_shift_soon_params,
    get_start_shift_request_id,
    get_start_shift_soon_request_id,
    get_problem_nobody_on_duty_request_id,
    get_problem_staff_has_gap_request_id,
    get_problem_nobody_on_duty_params,
    get_problem_staff_has_gap_params,
)
from watcher.logic.permissions import get_problems_responsibles
from watcher.db import Notification
from watcher import enums

logger = logging.getLogger(__name__)


@lock_task(save_metrics=True, send_to_unistat=True)
@dbconnect
def create_problem_notifications(session: Session):
    """Queue one notification per responsible person for every found problem.

    Problems are taken from ``query_active_problems_without_notifications``
    and their responsibles from ``get_problems_responsibles``. A problem with
    no responsibles is logged and skipped. All collected rows are written
    with a single bulk insert.

    :param session: SQLAlchemy session used for the queries and the insert.
    """
    logger.info('Starting create problem notification task')
    moment = now()

    # Which notification type is raised for each supported problem reason.
    reason_to_type = {
        enums.ProblemReason.nobody_on_duty: enums.NotificationType.problem_nobody_on_duty,
        enums.ProblemReason.staff_has_gap: enums.NotificationType.problem_staff_has_gap,
    }

    problems = query_active_problems_without_notifications(
        db=session, notification_type=reason_to_type, current_now=moment,
    ).all()
    responsibles = get_problems_responsibles(db=session, problems=problems)

    rows = []
    handled = 0
    for problem in problems:
        shift = problem.shift

        staff_ids = responsibles.get(problem)
        if not staff_ids:
            logger.warning(f'No responsible staff found for shift: {shift.id}')
            continue

        handled += 1
        # The notification is sent immediately and stays valid until the
        # shift ends.
        rows.extend(
            {
                'staff_id': staff_id,
                'shift_id': shift.id,
                'send_at': moment,
                'valid_to': shift.end,
                'type': reason_to_type[problem.reason],
            }
            for staff_id in staff_ids
        )

    if rows:
        session.bulk_insert_mappings(Notification, rows)

    skipped = len(problems) - handled
    logger.info(
        'Finish create problem notification task. Processed {} problems. Skipped {}'.format(
            handled, skipped,
        )
    )


@lock_task(save_metrics=True, send_to_unistat=True)
@dbconnect
def create_shift_start_soon_notifications(session: Session):
    """Queue ``start_shift_soon`` notifications for the selected shifts.

    The shift/staff pairs come from ``get_shifts_for_notify_start_soon``.
    Every created notification is scheduled for the current moment and is
    valid for ``settings.START_SOON_NOTIFY_VALID`` hours from it.

    :param session: SQLAlchemy session used for the query and the insert.
    """
    logger.info('Starting create start shift notification task')
    moment = now()
    expires_at = moment + datetime.timedelta(hours=settings.START_SOON_NOTIFY_VALID)

    shifts_data = get_shifts_for_notify_start_soon(db=session)
    logger.info(f'Creating {len(shifts_data)} notifications')

    if shifts_data:
        rows = []
        for entry in shifts_data:
            rows.append({
                'valid_to': expires_at,
                'shift_id': entry['shift_id'],
                'type': enums.NotificationType.start_shift_soon,
                'staff_id': entry['staff_id'],
                'send_at': moment,
            })
        session.bulk_insert_mappings(Notification, rows)

    logger.info('Finish create start shift notification task')


@lock_task(save_metrics=True, send_to_unistat=True)
@dbconnect
def send_notifications(session: Session):
    """Deliver pending notifications through JNS and record the outcome.

    For every notification returned by ``query_unsend_notifications``:

    * if ``valid_to`` is already in the past, it is marked ``outdated``;
    * if its type is one of the four handled types, the message is sent via
      ``jns_client`` and, on success, the notification is marked ``send``;
    * an ``HTTPError`` from the client is logged and the notification is left
      untouched;
    * any other type is logged as an error and left untouched.

    Outside of production, messages on the email channel are addressed to
    ``settings.ROBOT_LOGIN`` instead of the staff member's login.

    :param session: SQLAlchemy session used for the queries and the updates.
    """
    logger.info('Starting send_notifications task')
    moment = now()
    delivered_ids = []
    expired_ids = []

    pending = query_unsend_notifications(db=session).all()
    channels = get_notification_channels(
        db=session, staff_ids=[item.staff_id for item in pending],
    )

    for notification in pending:
        if notification.valid_to < moment:
            logger.info(f'Notification {notification.id} outdated, skipping')
            expired_ids.append(notification.id)
            continue

        # Pick the template, its parameters and the request id by type.
        template = params = request_id = None
        if notification.type == enums.NotificationType.start_shift:
            template, params, request_id = (
                settings.JNS_START_SHIFT_TEMPLATE,
                get_start_shift_params(notification=notification),
                get_start_shift_request_id(notification=notification),
            )
        elif notification.type == enums.NotificationType.start_shift_soon:
            template, params, request_id = (
                settings.JNS_START_SHIFT_SOON_TEMPLATE,
                get_start_shift_soon_params(notification=notification),
                get_start_shift_soon_request_id(notification=notification),
            )
        elif notification.type == enums.NotificationType.problem_nobody_on_duty:
            template, params, request_id = (
                settings.JNS_PROBLEM_NOBODY_ON_DUTY,
                get_problem_nobody_on_duty_params(notification=notification),
                get_problem_nobody_on_duty_request_id(notification=notification),
            )
        elif notification.type == enums.NotificationType.problem_staff_has_gap:
            template, params, request_id = (
                settings.JNS_PROBLEM_STAFF_HAS_GAP,
                get_problem_staff_has_gap_params(notification=notification),
                get_problem_staff_has_gap_request_id(notification=notification),
            )

        if not template:
            logger.error(
                f'Unexpected notification type: {notification.type}, '
                f'id: {notification.id}'
            )
            continue

        login = notification.staff.login
        channel = channels[notification.staff_id]
        # Outside of production, email messages go to the robot account
        # rather than to the staff member.
        redirect_to_robot = (
            channel == enums.JnsChannel.email
            and settings.ENV_TYPE != 'production'
        )
        if redirect_to_robot:
            login = settings.ROBOT_LOGIN

        try:
            jns_client.send_message(
                template=template,
                channel=channel,
                login=login,
                params=params,
                request_id=request_id,
            )
        except HTTPError:
            # Not recorded in either list, so its state is not updated below.
            logger.exception(f'Got exception while sending: {notification.id}')
        else:
            delivered_ids.append(notification.id)

    # Persist the outcome with one UPDATE per resulting state.
    outcomes = (
        (delivered_ids, enums.NotificationState.send),
        (expired_ids, enums.NotificationState.outdated),
    )
    for ids, state in outcomes:
        if not ids:
            continue
        matching = session.query(Notification).filter(Notification.id.in_(ids))
        new_values = {
            Notification.state: state,
            Notification.processed_at: now(),
        }
        matching.update(new_values, synchronize_session=False)

    logger.info('Finish send_notifications task')
