# coding: utf-8
from collections import defaultdict
from decimal import Decimal
import logging
from itertools import chain
from typing import Dict, List, Optional, Tuple

from django import forms
from django.forms import fields
from django.core.exceptions import ValidationError
from django.db import transaction
from django.utils.translation import gettext as _

from review.gradient import models
from review.shortcuts import const
from review.lib import (
    views,
    serializers as lib_serializers,
    forms as lib_forms,
)
from review.core.logic import bulk
from review.core.logic import legacy


# Module logger; used below to report umbrella / main product names from an
# uploaded sheet that could not be matched to existing records.
log = logging.getLogger(__name__)


def check_mode_builder(field):
    """Build a predicate telling whether a review's ``field`` mode is editable.

    The returned callable takes a review object, reads ``getattr(review, field)``
    and returns True when that mode is one of ``MODE_AUTO``, ``MODE_MANUAL`` or
    ``MODE_MANUAL_BY_CHOSEN`` from ``const.REVIEW_MODE``.
    """
    def is_editable_mode(review):
        current_mode = getattr(review, field)
        editable_modes = (
            const.REVIEW_MODE.MODE_AUTO,
            const.REVIEW_MODE.MODE_MANUAL,
            const.REVIEW_MODE.MODE_MANUAL_BY_CHOSEN,
        )
        return current_mode in editable_modes

    return is_editable_mode


class FileRowForm(forms.Form):
    """Base form for a single row of an uploaded sheet.

    Columns whose value equals ``const.DISABLED`` are dropped before the data
    reaches Django, so those fields behave as if they were not submitted.
    The ``id`` column is always kept. The review the row belongs to is stored
    on ``self.review``.
    """

    def __init__(self, data, review, **kwargs):
        self.review = review
        kept_columns = {}
        for column, cell in data.items():
            if cell != const.DISABLED or column == 'id':
                kept_columns[column] = cell
        super().__init__(data=kept_columns, **kwargs)


class PersonReviewImportValidateForm(FileRowForm):
    """Validate one row of an uploaded person-review sheet.

    Each declared field corresponds to a column of the sheet. ``clean``
    returns a ``(person_review_id, changes)`` tuple instead of a dict; Django
    stores whatever ``clean`` returns as ``cleaned_data``.
    """
    id = forms.IntegerField()

    mark_at_review = forms.CharField(
        required=False,
    )
    tag_average_mark = forms.CharField(
        required=False,
    )
    taken_in_average = lib_forms.NiceNullBooleanField(
        required=False,
    )
    salary_change_percentage = forms.DecimalField(
        required=False,
        min_value=const.VALIDATION.SALARY_CHANGE_MIN,
        max_value=const.VALIDATION.SALARY_CHANGE_MAX,
    )
    salary_change_absolute = forms.DecimalField(
        required=False,
        min_value=const.VALIDATION.SALARY_CHANGE_ABSOLUTE_MIN,
        max_value=const.VALIDATION.SALARY_CHANGE_ABSOLUTE_MAX,
    )
    grade_difference = forms.IntegerField(
        min_value=const.VALIDATION.LEVEL_CHANGE_MIN,
        max_value=const.VALIDATION.LEVEL_CHANGE_MAX,
        required=False
    )
    extra_payment = lib_forms.NiceNullBooleanField(
        required=False,
    )
    extra_option = lib_forms.NiceNullBooleanField(
        required=False,
    )
    bonus_payment_percentage = forms.DecimalField(
        required=False,
        min_value=const.VALIDATION.BONUS_MIN,
        max_value=const.VALIDATION.BONUS_MAX,
    )
    bonus = forms.DecimalField(
        required=False,
        min_value=const.VALIDATION.BONUS_ABSOLUTE_MIN,
        max_value=const.VALIDATION.BONUS_ABSOLUTE_MAX,
    )
    bonus_rsu = forms.IntegerField(
        min_value=const.VALIDATION.BONUS_RSU_MIN,
        max_value=const.VALIDATION.BONUS_RSU_MAX,
        required=False
    )
    deferred_payment = forms.DecimalField(
        min_value=const.VALIDATION.DEFERRED_PAYMENT_MIN,
        max_value=const.VALIDATION.DEFERRED_PAYMENT_MAX,
        required=False
    )
    bonus_option_value = forms.IntegerField(
        min_value=const.VALIDATION.OPTIONS_RSU_MIN,
        max_value=const.VALIDATION.OPTIONS_RSU_MAX,
        required=False
    )

    umbrella = forms.CharField(
        required=False,
    )
    main_product = forms.CharField(
        required=False,
    )

    # Maps every form field to a predicate on the review. When the predicate
    # returns True, a blank value is rejected if the column is present in the
    # submitted row. Mode-driven fields are required only when the review's
    # corresponding *_mode is auto/manual/manual-by-chosen.
    field_to_require_check = {
        form_field: check_mode_builder(model_field)
        for form_field, model_field in (
            ('mark_at_review', 'mark_mode'),
            ('salary_change_percentage', 'salary_change_mode'),
            ('salary_change_absolute', 'salary_change_mode'),
            ('grade_difference', 'level_change_mode'),
            ('extra_payment', 'goldstar_mode'),
            ('extra_option', 'goldstar_mode'),
            ('bonus_payment_percentage', 'bonus_mode'),
            ('bonus', 'bonus_mode'),
            ('bonus_rsu', 'bonus_mode'),
            ('deferred_payment', 'deferred_payment_mode'),
            ('bonus_option_value', 'options_rsu_mode'),
        )
    }
    field_to_require_check['id'] = lambda obj: True
    field_to_require_check['tag_average_mark'] = lambda obj: False
    field_to_require_check['taken_in_average'] = lambda obj: False
    field_to_require_check['umbrella'] = lambda obj: False
    field_to_require_check['main_product'] = lambda obj: False

    def clean(self):
        """Apply per-review rules to the row and split off its id.

        For every cleaned value:

        * when ``self.review`` is set, a blank value is rejected if the column
          is present in the submitted data and the field's
          ``field_to_require_check`` predicate returns True;
        * ``mark_at_review`` display text is translated into a scale key, and
          a non-empty mark that is neither a key of the review's scale nor one
          of ``const.MARK.SPECIAL_VALUES`` is rejected;
        * ``Decimal`` values are quantized to two decimal places.

        Returns:
            ``(person_review_id, changes)``, where ``changes`` contains only
            non-empty values. ``person_review_id`` is ``None`` when the ``id``
            column failed field validation; that error is already recorded on
            the form.
        """
        cleaned = super(PersonReviewImportValidateForm, self).clean()
        # Iterate over a snapshot: add_error() removes the field from
        # cleaned_data, which is the same dict as ``cleaned``.
        for field_name, value in list(cleaned.items()):
            is_empty = value in fields.EMPTY_VALUES
            if self.review:
                required = self.field_to_require_check[field_name](self.review)
                if field_name == 'mark_at_review':
                    value = self._normalize_mark(value)
                    cleaned[field_name] = value
                    known_marks = chain(
                        getattr(self.review.scale, 'scale', []),
                        const.MARK.SPECIAL_VALUES,
                    )
                    if value and value not in known_marks:
                        err = _("Select a valid choice. That choice is not one of the available choices.")
                        self.add_error(field_name, ValidationError(err))
            else:
                required = False
            # Blank cells are only reported for columns actually present in the row.
            if is_empty and required and field_name in self.data:
                self.add_error(field_name, ValidationError(_('This field cannot be blank.')))
            if isinstance(value, Decimal):
                cleaned[field_name] = value.quantize(Decimal('1.00'))
        # 'id' is absent from cleaned_data when it failed field validation.
        # Popping with a default keeps that a form error instead of a KeyError.
        person_review_id = cleaned.pop('id', None)
        changes = {
            key: value
            for key, value in cleaned.items()
            if value not in fields.EMPTY_VALUES
        }
        return person_review_id, changes

    def _normalize_mark(self, value):
        """Translate a mark's display text into its key in the review's scale.

        Only scale entries whose data is a dict are considered; they are
        matched on ``text_value`` and the first match wins.
        ``const.MARK.TEXT_NO_MARK`` becomes ``const.MARK.NO_MARK``. Any other
        value is returned unchanged.
        """
        scale = getattr(self.review.scale, 'scale', {})
        for mark, data in scale.items():
            # Stop at the first match. Continuing would compare later entries'
            # text against the already-translated key and could remap it again.
            if isinstance(data, dict) and data['text_value'] == value:
                value = mark
                break
        if value == const.MARK.TEXT_NO_MARK:
            value = const.MARK.NO_MARK
        return value


class PersonReviewImportForm(forms.Form):
    """Accept an uploaded person-review sheet and turn it into bulk updates."""
    data_file = forms.FileField()

    def clean(self):
        """Parse the uploaded sheet into per-review updates.

        Each row is validated with ``PersonReviewImportValidateForm`` via
        ``lib_serializers.serialize_file_sheet``. Umbrella and main product
        names are then resolved to model instances, and every row is passed
        through ``normalize_legacy_format``.

        Returns:
            ``{'updates': {person_review_id: changes}}``.

        Raises:
            ValidationError: if the form already has errors (e.g. for
                ``data_file``).
        """
        # Stop before reading cleaned['data_file'] if field validation failed.
        if not self.is_valid():
            raise ValidationError(self.errors)
        cleaned = super(PersonReviewImportForm, self).clean()
        serialized_file = lib_serializers.serialize_file_sheet(
            cleaned['data_file'],
            PersonReviewImportValidateForm,
        )
        # serialized_file is iterated several times (here and in the helpers,
        # which annotate it as a List), so it must be re-iterable.
        # Rows are mutated in place: name strings become model instances.
        self._set_gradient_models(serialized_file)
        return {
            'updates': {
                person_review_id: self.normalize_legacy_format(row)
                for person_review_id, row in serialized_file
            }
        }

    @classmethod
    def normalize_legacy_format(cls, row):
        """Convert a row keyed by legacy column names into current field names.

        Keys listed as values of ``legacy.EXPORT_NEW_TO_OLD`` are renamed to
        the matching keys. ``extra_payment`` and ``extra_option`` are always
        removed from ``row``. They are combined into ``const.FIELDS.GOLDSTAR``
        only when both are present; if just one is present it is dropped.
        All remaining keys are copied unchanged.

        Note: ``row`` is mutated (consumed keys are popped).
        """
        result = {}
        for new_field, old_field in legacy.EXPORT_NEW_TO_OLD.items():
            if old_field not in row:
                continue
            result[new_field] = row.pop(old_field)

        extra_payment, extra_option = row.pop('extra_payment', None), row.pop('extra_option', None)
        if None not in (extra_payment, extra_option):
            result[const.FIELDS.GOLDSTAR] = legacy.goldstar_old_to_new(
                extra_payment=extra_payment,
                extra_option=extra_option,
            )
        result.update(row)
        return result

    @classmethod
    def _set_gradient_models(cls, serialized: List[Tuple[int, Dict]]):
        """Replace umbrella / main product names in each row with model instances.

        Rows are updated in place. A name that matches no existing record is
        logged as a warning and removed from its row.
        """
        gradient_info = cls._get_gradient_models(serialized)
        main_product_to_name, umbrella_to_mp_to_name = gradient_info

        for id_, row in serialized:
            umbrella = row.get('umbrella')
            if umbrella:
                mp_name, u_name = cls._get_names_for_umbrella(umbrella)
                model = umbrella_to_mp_to_name[u_name].get(mp_name)
                if model:
                    row['umbrella'] = model
                else:
                    log.warning(
                        'Not found umbrella %s for %s',
                        umbrella,
                        id_,
                    )
                    row.pop('umbrella', None)
            main_product = row.get('main_product')
            if main_product:
                model = main_product_to_name.get(main_product)
                if model:
                    row['main_product'] = model
                else:
                    log.warning(
                        'Not found main_product %s for %s',
                        main_product,
                        id_,
                    )
                    row.pop('main_product', None)

    @staticmethod
    def _get_names_for_umbrella(umbrella: str) -> Tuple[Optional[str], str]:
        """Split ``'<main product> / <umbrella>'`` into its two names.

        Only the first ``' / '`` separates the parts. Without a separator the
        whole string is the umbrella name and the main product name is ``None``.
        """
        if ' / ' not in umbrella:
            return None, umbrella
        mp_name, u_name = umbrella.split(' / ', 1)
        return mp_name, u_name

    @classmethod
    def _get_gradient_models(
        cls,
        serialized: List[Tuple[int, Dict]],
    ) -> Tuple[
        Dict[str, Optional[models.MainProduct]],
        Dict[str, Dict[Optional[str], Optional[models.Umbrella]]],
    ]:
        """Load the MainProduct / Umbrella records referenced by the rows.

        Returns:
            ``(main_product_to_name, umbrella_to_mp_to_model)``:

            * main product name -> ``MainProduct`` (or ``None`` if not found);
            * umbrella name -> {main product name or ``None`` -> ``Umbrella``
              or ``None`` if not found}.

            Queries are ordered by ``id`` and later records overwrite earlier
            ones, so among same-named records the highest ``id`` wins.
        """
        main_product_to_name = {}
        umbrella_to_mp_to_model = defaultdict(dict)
        # The loop variable is not named ``_``, which would shadow gettext.
        for _row_id, row in serialized:
            umbrella = row.get('umbrella')
            main_product = row.get('main_product')
            if umbrella:
                mp_name, u_name = cls._get_names_for_umbrella(umbrella)
                umbrella_to_mp_to_model[u_name][mp_name] = None
            if main_product:
                main_product_to_name[main_product] = None
        if not (main_product_to_name or umbrella_to_mp_to_model):
            return main_product_to_name, umbrella_to_mp_to_model

        umbrella_q = (
            models.Umbrella.objects
            .filter(name__in=umbrella_to_mp_to_model)
            .select_related('main_product')
            .order_by('id')
        )
        for umbrella in umbrella_q:
            mp_to_model = umbrella_to_mp_to_model[umbrella.name]
            mp_name = umbrella.main_product and umbrella.main_product.name
            # Only fill (umbrella, main product) pairs that a row asked for.
            if mp_name in mp_to_model:
                mp_to_model[mp_name] = umbrella

        main_product_q = (
            models.MainProduct.objects
            .filter(name__in=main_product_to_name)
            .order_by('id')
        )
        for main_product in main_product_q:
            main_product_to_name[main_product.name] = main_product

        return main_product_to_name, umbrella_to_mp_to_model


class PersonReviewImportView(views.View):
    """Apply person-review changes parsed from an uploaded sheet.

    All changes run inside one transaction. If any person review comes back
    as ``const.NO_ACCESS``, or reports a failed field whose value is not
    ``const.DISABLED``, the transaction is rolled back and the per-review
    results are returned with status 400.
    """
    form_cls_post = PersonReviewImportForm

    @transaction.atomic
    def process_post(self, auth, data):
        changes_by_review = bulk.bulk_different_action_set(
            subject=auth.user,
            data=data['updates'],
            subject_type=const.PERSON_REVIEW_CHANGE_TYPE.FILE,
        )

        # Re-key the results by person review id for the response.
        result = {}
        for person_review, changes in changes_by_review.items():
            result[person_review.id] = changes

        def is_failed(changes):
            if changes == const.NO_ACCESS:
                return True
            return any(value != const.DISABLED for value in changes['failed'].values())

        # Evaluate every entry before deciding, as a full list.
        failure_flags = [is_failed(changes) for changes in result.values()]
        if not any(failure_flags):
            return result

        transaction.set_rollback(True)
        return self.do_response(
            request=self.request,
            response=result,
            status_code=400,
        )
