import base64
import dataclasses
import json
import os
from collections import OrderedDict, defaultdict
from typing import Callable, Optional, List

from django.core.serializers import base
from django.core.serializers.json import Serializer
from django.core.serializers.python import _get_model
from django.db import DEFAULT_DB_ALIAS, models
from django.db import transaction
from django.utils.encoding import force_text
from wiki.grids.models import Revision
from wiki.org import org_ctx
from wiki.pages.models import Page


class FixtureProvider:
    """In-memory registry of restored objects, indexed by their original pk.

    Objects are bucketed by the ``__name__`` of their class, so two models
    that share a class name (e.g. in different apps) share a bucket.
    """

    def __init__(self):
        # {model class __name__: {original pk: instance}}
        self.collection = defaultdict(dict)

    def register(self, instance, pk):
        """Record ``instance`` as the object that had primary key ``pk``."""
        bucket = self.collection[type(instance).__name__]
        bucket[pk] = instance

    def lookup(self, Model, pk):
        """Return the instance registered for ``Model`` under original ``pk``.

        Raises:
            KeyError: if nothing was registered for that model/pk pair.
        """
        bucket = self.collection[Model.__name__]
        return bucket[pk]


@dataclasses.dataclass
class RestorationOptions:
    """Per-fixture knobs consumed by ``restore_model``."""

    # Invoked with each instance right after it has been saved (if set).
    on_success: Optional[Callable] = dataclasses.field(default=None)
    # Serialized field names to skip when rebuilding an instance.
    ignore_fields: Optional[List] = dataclasses.field(default_factory=list)


def restore_model(d, fp: FixtureProvider, drop_pk=True, ignore=False, after_save=None, options=None):
    """Rebuild and save one model instance from a serialized fixture entry.

    ``d`` is one serialized object: a dict with a ``model`` label, an optional
    ``pk``, a ``fields`` mapping and an optional ``extra_fields`` mapping whose
    items are ``setattr``-ed onto the instance before saving.

    Foreign-key and many-to-many references are resolved through ``fp``: every
    referenced original pk must already be registered there. The restored
    object is itself registered under its original pk, so later entries can
    point at it.

    Args:
        d: the serialized object.
        fp: registry used to resolve and record original primary keys.
        drop_pk: if True (default), the original pk is not copied onto the new
            instance (for auto-incrementing pks the database picks a new one).
        ignore: if True, return None for unknown models and silently skip
            fields that no longer exist on the model.
        after_save: unused; kept only for backward compatibility. Use
            ``options.on_success`` instead.
        options: a ``RestorationOptions``; defaults to an empty one.

    Returns:
        The saved instance, or None if the model is unknown and ``ignore``.

    Raises:
        base.DeserializationError: unknown model (without ``ignore``), a value
            that cannot be converted, or an FK/M2M reference missing from ``fp``.
        FieldDoesNotExist: a field absent from the model when ``ignore`` is False.
    """
    options = options or RestorationOptions()
    # ignore_fields is typed Optional; treat None like "nothing to ignore".
    ignore_fields = options.ignore_fields or ()
    db = DEFAULT_DB_ALIAS

    # Look up the model and start building a dict of data for it.
    try:
        Model = _get_model(d['model'])
    except base.DeserializationError:
        if ignore:
            return None
        raise

    data = {}
    original_pk = None
    if 'pk' in d:
        try:
            original_pk = Model._meta.pk.to_python(d.get('pk'))
        except Exception as e:
            raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), None)
        if not drop_pk:
            data[Model._meta.pk.attname] = original_pk

    m2m_data = {}
    field_names = {f.name for f in Model._meta.get_fields()}

    for field_name, field_value in d['fields'].items():
        if field_name in ignore_fields:
            continue
        if ignore and field_name not in field_names:
            # Skip fields that were dropped from the model since the dump.
            continue

        if isinstance(field_value, str):
            field_value = force_text(field_value, strings_only=True)

        field = Model._meta.get_field(field_name)

        if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel):
            # M2M: remap each original related pk to the restored object's pk.
            model = field.remote_field.model
            related_pks = []
            # Reset per field so the error below never reports a stale pk
            # (or hits an unbound name when field_value is not iterable).
            related_pk = None
            try:
                for related_pk in field_value:
                    related_pks.append(fp.lookup(model, related_pk).pk)
            except Exception as e:
                raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), related_pk)
            m2m_data[field.name] = related_pks

        elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel):
            # FK: remap the original target pk to the restored target's pk.
            model = field.remote_field.model
            if field_value is not None:
                try:
                    data[field.attname] = fp.lookup(model, field_value).pk
                except Exception as e:
                    raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value)
            else:
                data[field.attname] = None

        else:
            try:
                data[field.name] = field.to_python(field_value)
            except Exception as e:
                raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value)

    obj = base.build_instance(Model, data, db)
    # Plain Django fixtures have no 'extra_fields' key.
    for field_name, field_value in d.get('extra_fields', {}).items():
        setattr(obj, field_name, field_value)

    obj.save()
    # M2M links need a saved row, so apply them only after save(), as Django's
    # own fixture loader does. Previously they were resolved and then dropped.
    for accessor_name, related_pks in m2m_data.items():
        getattr(obj, accessor_name).set(related_pks)

    if options.on_success:
        options.on_success(obj)

    fp.register(obj, original_pk)
    return obj


class WikiJsonSerializer(Serializer):
    """Django JSON serializer that additionally dumps non-field attributes.

    Besides the usual ``model``/``pk``/``fields`` keys, every dumped object
    carries an ``extra_fields`` mapping filled from two ``serialize`` options:

    * ``extra_fields``: attribute names read with ``getattr(obj, name)``;
    * ``extra_dyn_fields``: a ``{name: callable}`` mapping; each callable is
      called with the object and its result is stored under ``name``.
    """

    def get_dump_object(self, obj):
        """Return the ordered dict written to the stream for ``obj``."""
        data = OrderedDict([('model', force_text(obj._meta))])
        if not self.use_natural_primary_keys or not hasattr(obj, 'natural_key'):
            data['pk'] = force_text(obj._get_pk_val(), strings_only=True)
        data['fields'] = self._current
        data['extra_fields'] = self._extra_fields
        return data

    def serialize(self, queryset, **options):
        """
        Serialize a queryset.

        Accepts the standard serializer options plus ``extra_fields`` and
        ``extra_dyn_fields`` (see the class docstring). Output is always
        written with ``ensure_ascii=False``.
        """
        options['ensure_ascii'] = False
        self.options = options

        # Pop our own options so they are not forwarded to json.dump().
        self.stream = options.pop('stream', self.stream_class())
        self.selected_fields = options.pop('fields', None)
        self.extra_fields = options.pop('extra_fields', [])
        self.extra_dyn_fields = options.pop('extra_dyn_fields', {})
        self._extra_fields = {}
        self.use_natural_foreign_keys = options.pop('use_natural_foreign_keys', False)
        self.use_natural_primary_keys = options.pop('use_natural_primary_keys', False)
        progress_bar = self.progress_class(options.pop('progress_output', None), options.pop('object_count', 0))

        self.start_serialization()
        self.first = True
        for count, obj in enumerate(queryset, start=1):
            self.start_object(obj)
            # Use the concrete parent class' _meta instead of the object's _meta
            # This is to avoid local_fields problems for proxy models. Refs #17717.
            concrete_model = obj._meta.concrete_model

            # Build a fresh mapping per object: get_dump_object() stores a
            # reference to it, so reusing one dict would make every dump object
            # alias the same (last-written) values.
            extra = {name: getattr(obj, name) for name in self.extra_fields}
            for name, fn in self.extra_dyn_fields.items():
                extra[name] = fn(obj)
            self._extra_fields = extra

            for field in concrete_model._meta.local_fields:
                if field.serialize:
                    if field.remote_field is None:
                        if self.selected_fields is None or field.attname in self.selected_fields:
                            self.handle_field(obj, field)
                    else:
                        # attname is "<name>_id"; selected_fields uses "<name>".
                        if self.selected_fields is None or field.attname[:-3] in self.selected_fields:
                            self.handle_fk_field(obj, field)
            for field in concrete_model._meta.many_to_many:
                if field.serialize:
                    if self.selected_fields is None or field.attname in self.selected_fields:
                        self.handle_m2m_field(obj, field)
            self.end_object(obj)
            progress_bar.update(count)
            self.first = False
        self.end_serialization()
        return self.getvalue()

    def handle_m2m_field(self, obj, field):
        """Store the related objects' keys (natural or pk) under ``field.name``.

        Fields whose ``through`` model is not auto-created are skipped.
        """
        if field.remote_field.through._meta.auto_created:
            if self.use_natural_foreign_keys and hasattr(field.remote_field.model, 'natural_key'):

                def m2m_value(value):
                    return value.natural_key()

            else:

                def m2m_value(value):
                    return force_text(value._get_pk_val(), strings_only=True)

            self._current[field.name] = [m2m_value(related) for related in getattr(obj, field.name).all()]


# Id of the source organisation whose pages dump() exports.
org = 248_961


def b64file(file):
    """Return the bytes of ``file.mds_storage_id.file`` as a base64 ``str``.

    Reads from the underlying file object's current position to its end.
    """
    raw_bytes = file.mds_storage_id.file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()


def dump(path='./hermiona.json'):
    """Export the source org's status-1 pages whose supertag starts with 'hermiona'.

    Each page is written through ``WikiJsonSerializer`` with its ``body``
    attribute included under ``extra_fields``.

    Args:
        path: output file; defaults to ``./hermiona.json``.
    """
    pages = list(Page.objects.filter(org_id=org, status=1, supertag__startswith='hermiona'))
    ws = WikiJsonSerializer()
    data = ws.serialize(pages, extra_fields=['body'])
    # The serializer emits raw non-ASCII text (ensure_ascii=False), so pin the
    # encoding instead of depending on the process locale.
    with open(path, 'w', encoding='utf-8') as out_file:
        out_file.write(data)


def bake_body(page):
    """Post-save hook for restored pages.

    Creates a grid ``Revision`` from the page, then adds the page's owner to
    its authors.
    """
    Revision.objects.create_from_page(page)
    authors = page.authors
    authors.add(page.owner)


@transaction.atomic
def provision_org(user, org, base_dir='/stage_provision'):
    """Restore the staged wiki fixtures into ``org`` inside one transaction.

    All listed source user ids (and their ``staff`` objects) are remapped to
    ``user``, and the source org id is remapped to ``org``. The fixture files
    are then replayed in order through ``restore_model``.

    Args:
        user: the user every source author is mapped onto.
        org: target org; if falsy, the user's first org is used.
        base_dir: directory containing ``pages.json`` and ``acl.json``;
            defaults to ``/stage_provision``.

    Raises:
        IndexError: if ``org`` is falsy and the user belongs to no org.
    """
    fp = FixtureProvider()
    org = org or user.orgs.all()[0]

    source_user_ids = [1115681, 1778075, 1115687, 1830073]
    for user_id in source_user_ids:
        fp.register(user, user_id)
        fp.register(user.staff, user_id)

    # Source org id (the module-level ``org`` constant is shadowed here).
    fp.register(org, 248961)

    with org_ctx(org):
        fixtures = [
            ('pages.json', RestorationOptions(on_success=bake_body, ignore_fields=['mds_storage_id'])),
            ('acl.json', None),
        ]
        for file_name, options in fixtures:
            # Fixtures contain raw non-ASCII text; don't rely on the locale encoding.
            with open(os.path.join(base_dir, file_name), encoding='utf-8') as fixture_file:
                objs = json.load(fixture_file)
            for obj in objs:
                restore_model(obj, fp, options=options)
