# coding: utf-8
from __future__ import absolute_import, unicode_literals

import dataclasses
import json
import itertools
import logging
import collections

from django.conf import settings

from intranet.crt.constants import TASK_TYPE
from intranet.crt.exceptions import CrtTimestampError
from intranet.crt.tasks.base import CrtBaseTask
from intranet.crt.tags.serializers import NocCertSerializer
from intranet.crt.utils.file import SafetyWriteFile
from intranet.crt.utils.cvs import NocCvsClient
from intranet.crt.utils.tags import TagsDiffDict

# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Flat, immutable record for one certificate; the write_* helpers in this
# module read its fields when producing the CSV and JSON files.
SerializedCertificate = collections.namedtuple('SerializedCertificate', ['serial', 'tags', 'common_name', 'ca_name', 'type'])


def get_threshold(cert_type, action):
    """Return the configured limit for ``action`` on certificates of ``cert_type``.

    The limits live in ``settings.CRT_NOC_THRESHOLDS``. The entry keyed by
    ``cert_type`` is used when present, otherwise the ``'default'`` entry.
    Inside the chosen entry the value stored under ``action`` is returned,
    falling back to the ``'default'`` entry's value for the same action.

    Raises:
        KeyError: if there is no ``'default'`` entry, or it has no value for
            ``action`` (the fallback is looked up even when it is not used).
    """
    thresholds = settings.CRT_NOC_THRESHOLDS
    defaults = thresholds['default']
    per_type = thresholds.get(cert_type, defaults)
    # Looked up unconditionally, exactly like an eagerly evaluated default.
    fallback = defaults[action]
    return per_type.get(action, fallback)


def are_diffs_equal(old_diff, new_diff):
    """Tell whether a stored diff and a freshly computed diff match.

    Args:
        old_diff - dictionary loaded from file
        new_diff - diff from NocCertificatesDiff.diff; may contain
            OrderedDict's, and its 'tags' value exposes ``added`` and
            ``removed`` attributes (a TagsDiffDict)
    Returns:
        False when either diff lacks the 'certificates' or 'tags' key,
        otherwise whether both the tags and the certificates parts match

    """
    def _same(left, right):
        # Recurse only into dicts that share the same key set; everything
        # else is compared with plain equality.
        both_dicts = isinstance(left, dict) and isinstance(right, dict)
        if not both_dicts or set(left.keys()) != set(right.keys()):
            return left == right
        for key in left.keys():
            if not _same(left[key], right[key]):
                return False
        return True

    for section in ('certificates', 'tags'):
        if section not in old_diff or section not in new_diff:
            return False

    fresh_tags = new_diff['tags']
    expected_tags = {
        'added': fresh_tags.added,
        'removed': fresh_tags.removed,
    }
    tags_match = old_diff['tags'] == expected_tags
    certificates_match = _same(old_diff['certificates'], new_diff['certificates'])
    return tags_match and certificates_match


class NocCertificatesDiff:
    """Changes between two certificate exports plus threshold checks on them.

    Attributes:
        tags: the ``changed_tags`` object passed to the constructor.
        certificates: dict with the keys ``'new'``, ``'changed'`` and
            ``'removed'``; each value maps a key to a certificate data dict
            that must contain a ``'type'`` entry.
    """

    # One (action, certificate type) pair with the number of affected
    # certificates and the configured limit for that pair.
    ActionTypeLimit = collections.namedtuple('ActionTypeLimit', ['action', 'cert_type', 'count', 'threshold'])

    def __init__(self, new_certificates, changed_tags, changed_certificates, removed_certificates):
        self.tags = changed_tags
        self.certificates = {
            'new': new_certificates,
            'changed': changed_certificates,
            'removed': removed_certificates,
        }

    @property
    def diff(self):
        """Return the tags and certificates changes as a single dict."""
        return {
            'tags': self.tags,
            'certificates': self.certificates,
        }

    def _actions_by_types_counts_iter(self):
        """Yield an ``ActionTypeLimit`` for every action/certificate-type pair.

        Pairs are produced action by action; within one action there is one
        item per certificate type found among that action's certificates.
        """
        for action, cert_changes in self.certificates.items():
            types_counter = collections.Counter(cert_change['type'] for cert_change in cert_changes.values())
            for cert_type, cert_type_count in types_counter.items():
                yield NocCertificatesDiff.ActionTypeLimit(
                    action=action,
                    cert_type=cert_type,
                    count=cert_type_count,
                    threshold=get_threshold(cert_type, action)
                )

    def is_safe(self):
        """Return True when no action/type count exceeds its threshold."""
        return all(
            action_type_limit.count <= action_type_limit.threshold
            for action_type_limit in self._actions_by_types_counts_iter()
        )

    def has_changes(self):
        """Return True when at least one certificate is new, changed or removed."""
        return any(
            len(data) > 0
            for data in self.certificates.values()
        )

    def make_error_msg(self):
        """Build a one-line summary of counts versus thresholds.

        The result has one ``<type>: <action>=<count>/<threshold>,...``
        section per certificate type, followed by a ``tag_changes`` section
        when ``repr(self.tags)`` is non-empty; sections are joined by `` | ``.
        """
        # The limits arrive ordered by action, not by certificate type, so
        # itertools.groupby (which only merges adjacent items) would split a
        # type into several sections. Collect per type instead; the dict
        # keeps types in first-seen order and needs no sortable keys.
        changes_by_type = {}
        for limit in self._actions_by_types_counts_iter():
            changes_by_type.setdefault(limit.cert_type, []).append(
                f'{limit.action}={limit.count}/{limit.threshold}'
            )

        errors = [
            '{cert_type}: {type_changes}'.format(
                cert_type=cert_type,
                type_changes=','.join(type_changes),
            )
            for cert_type, type_changes in changes_by_type.items()
        ]
        tags_repr = repr(self.tags)
        if tags_repr:
            errors.append(f"tag_changes: {tags_repr}")
        return ' | '.join(errors)


def compare_cert_sets(new_certs, old_certs):
    """Compute what changed between two certificate exports.

    Args:
        new_certs: dict whose ``'certificates'`` value is an iterable of
            certificate dicts, each with ``'serial_number'`` and ``'tags'``.
        old_certs: dict of the same shape describing the previous state.

    Returns:
        NocCertificatesDiff holding:
          * new certificates - serials present only in ``new_certs``;
          * removed certificates - serials present only in ``old_certs``;
          * changed certificates - serials present in both whose tag sets
            differ; each entry is the new certificate data without
            ``'tags'`` plus ``'added_tags'`` / ``'removed_tags'`` lists;
          * a TagsDiffDict counting how many certificates gained or lost
            each tag.
    """
    actual_certificates = {data['serial_number']: data for data in new_certs['certificates']}
    outdated_certificates = {data['serial_number']: data for data in old_certs['certificates']}

    new_serials = set(actual_certificates)
    old_serials = set(outdated_certificates)

    added_tags_counter = collections.Counter()
    removed_tags_counter = collections.Counter()
    changed_certificates = {}
    for serial in new_serials & old_serials:
        actual = actual_certificates[serial]
        new_cert_tags = set(actual['tags'])
        old_cert_tags = set(outdated_certificates[serial]['tags'])

        if new_cert_tags == old_cert_tags:
            continue

        cert_changed_data = {key: value for key, value in actual.items() if key != 'tags'}
        # Sort the tag lists: set iteration order for strings differs between
        # processes, and these lists are later compared (order-sensitively)
        # with a diff stored by an earlier run.
        added_tags = sorted(new_cert_tags - old_cert_tags)
        if added_tags:
            cert_changed_data['added_tags'] = added_tags
            added_tags_counter.update(added_tags)
        removed_tags = sorted(old_cert_tags - new_cert_tags)
        if removed_tags:
            cert_changed_data['removed_tags'] = removed_tags
            removed_tags_counter.update(removed_tags)

        changed_certificates[serial] = cert_changed_data

    changed_tags = TagsDiffDict(
        added=dict(added_tags_counter),
        removed=dict(removed_tags_counter),
    )
    new_data = {
        serial: actual_certificates[serial]
        for serial in new_serials - old_serials
    }
    removed_data = {
        serial: outdated_certificates[serial]
        for serial in old_serials - new_serials
    }

    return NocCertificatesDiff(
        changed_tags=changed_tags,
        new_certificates=new_data,
        changed_certificates=changed_certificates,
        removed_certificates=removed_data,
    )


def is_diff_valid(cvs_client: NocCvsClient, new_certs) -> bool:
    """Decide whether ``new_certs`` may be written out and committed.

    The previous certificates are taken from
    ``cvs_client.get_old_certs_with_types()`` and compared with ``new_certs``.

    Returns:
        True when the diff is within thresholds and contains certificate
        changes; False when it is within thresholds but empty.

    Raises:
        CrtTimestampError: when the diff exceeds a threshold. Before raising,
            the diff is written to ``settings.CRT_CVS_DIFF_FILE`` and
            committed, unless it equals the diff ``cvs_client`` already has.
    """
    previous = {'certificates': cvs_client.get_old_certs_with_types()}
    certs_diff = compare_cert_sets(new_certs, previous)

    if certs_diff.is_safe():
        if certs_diff.has_changes():
            return True
        log.info('Has no changes in tags. Skip cvs commit')
        return False

    error_msg = certs_diff.make_error_msg()
    diff_message = f'Update diff: {error_msg}'
    stored_diff = cvs_client.get_diff()['diff']
    if not are_diffs_equal(stored_diff, certs_diff.diff):
        # Record the oversized diff only when it is not the one already stored.
        payload = {
            'diff': certs_diff.diff,
            'error_msg': error_msg,
        }
        write_json_file(settings.CRT_CVS_DIFF_FILE, payload)
        cvs_client.commit(diff_message)
    raise CrtTimestampError(diff_message)


class JsonDiffEncoder(json.JSONEncoder):
    """JSON encoder that additionally knows how to serialise TagsDiffDict."""

    def default(self, obj):
        """Convert ``obj`` into something the json module can encode.

        TagsDiffDict instances are converted with ``dataclasses.asdict``.
        Anything else is delegated to ``json.JSONEncoder.default``, which
        raises ``TypeError`` for unsupported objects.
        """
        if isinstance(obj, TagsDiffDict):
            return dataclasses.asdict(obj)
        # super() is already bound to self; passing self again made this call
        # fail with an argument-count TypeError instead of the standard
        # "not JSON serializable" error.
        return super().default(obj)


def write_json_file(cvs_filename, data):
    """Dump ``data`` as JSON (4-space indent) into ``cvs_filename``.

    The file is opened through ``SafetyWriteFile`` and values are encoded
    with ``JsonDiffEncoder``.
    """
    with SafetyWriteFile(cvs_filename) as target:
        json.dump(data, fp=target, cls=JsonDiffEncoder, indent=4)


# This export exists only so that NOC staff can conveniently review the data by hand.
def write_csv_for_noc(certs):
    """Write one pipe-delimited line per certificate to ``settings.CRT_CVS_CSV_FILE``.

    Each line has the form ``|serial|tag1,tag2|common_name|ca_name|``.
    """
    def _rows():
        # Lines are produced lazily while the file is being written.
        for entry in certs:
            yield f'|{entry.serial}|{",".join(entry.tags)}|{entry.common_name}|{entry.ca_name}|\n'

    with SafetyWriteFile(settings.CRT_CVS_CSV_FILE) as noc_file:
        noc_file.writelines(_rows())


def write_json_for_noc(certs, meta):
    """Write the certificates and ``meta`` as JSON to ``settings.CRT_CVS_JSON_FILE``.

    Every certificate is stored with its serial number, tags, common name
    and CA name; the certificate type is not included in this file.
    """
    target = settings.CRT_CVS_JSON_FILE
    records = []
    for cert in certs:
        records.append(dict(
            serial_number=cert.serial,
            tags=cert.tags,
            common_name=cert.common_name,
            ca_name=cert.ca_name,
        ))
    write_json_file(target, {'certificates': records, 'meta': meta})


def write_json_with_old_certs(certs):
    """Write the certificates, including their type, to ``settings.CRT_CVS_OLD_DATA_FILE``.

    The file holds a JSON list with one object per certificate.
    """
    def _as_record(cert):
        return dict(
            serial_number=cert.serial,
            tags=cert.tags,
            common_name=cert.common_name,
            ca_name=cert.ca_name,
            type=cert.type,
        )

    target = settings.CRT_CVS_OLD_DATA_FILE
    write_json_file(target, [_as_record(cert) for cert in certs])


def write_files(db_data):
    """Write the CSV file and both JSON files from ``db_data``.

    ``db_data`` must provide a ``'certificates'`` iterable of dicts (with
    ``serial_number``, ``tags``, ``common_name``, ``ca_name`` and ``type``)
    and a ``'meta'`` value that is stored in the NOC JSON file.
    """
    records = [
        SerializedCertificate(
            row['serial_number'],
            row['tags'],
            row['common_name'],
            row['ca_name'],
            row['type'],
        )
        for row in db_data['certificates']
    ]
    meta = db_data['meta']

    # Same order as always: CSV first, then the NOC JSON, then the JSON
    # that also carries the certificate types.
    for writer, args in (
        (write_csv_for_noc, (records,)),
        (write_json_for_noc, (records, meta)),
        (write_json_with_old_certs, (records,)),
    ):
        writer(*args)


class SyncCvsTagsTask(CrtBaseTask):
    """Task that exports certificate data from the database and commits it via CVS."""

    task_type = TASK_TYPE.SYNC_CVS_TAGS
    lock_name = settings.CRT_SYNC_TAGS_LOCK_NAME

    def run(self, timestamp=None, force=False, commit_message=None, **kwargs):
        """Serialise the certificates, validate the diff and commit the files.

        Args:
            timestamp: object whose ``timestamp.start`` is passed to
                ``NocCertSerializer.from_db``.
            force: when true, the diff validation is skipped entirely.
            commit_message: message passed to ``cvs_client.commit``.

        Nothing is written or committed when ``force`` is false and
        ``is_diff_valid`` returns False; an exception raised by
        ``is_diff_valid`` propagates to the caller.
        """
        cvs_client = NocCvsClient()
        db_data = NocCertSerializer.from_db(timestamp.timestamp.start).data
        # NOTE(review): presumably refreshes the local CVS checkout - confirm.
        cvs_client.up()

        should_publish = force or is_diff_valid(cvs_client, db_data)
        if not should_publish:
            return

        write_files(db_data)
        cvs_client.commit(commit_message)
