# -*- coding: utf-8 -*-
from collections import namedtuple
import logging

from passport.backend.core.tracks.differ import differ
from passport.backend.core.tracks.exceptions import ConcurrentTrackOperationError
from passport.backend.core.tracks.reader import decode_redis_data
from passport.backend.core.tracks.utils import (
    make_redis_key,
    make_redis_subkey,
)
from passport.backend.utils.string import smart_text
from six import iteritems


# Maximum number of items kept in a single track list; when a list grows past
# this length, Serializer trims it while writing the list to Redis.
MAX_LIST_LENGTH = 1000

# Fake element written at the head of every list
LIST_ROOT = '<root>'


# Module logger; used to report tracks whose stored state no longer matches
# the state that was read.
log = logging.getLogger('passport.tracks.Serializer')


# Values shared by the Serializer helpers while writing one track: the track as
# it was read (None when creating), the track to be written, their diff, the
# Redis keys of the track hash and of its version counter, and the create flag.
SerializationParams = namedtuple('SerializationParams', 'old_track new_track diff track_key version_key is_create')


class Serializer(object):
    """Turn the difference between two track states into Redis commands.

    Every method is static and works on the pipeline object passed to it.
    ``execute`` is the entry point; ``create_or_change`` and ``delete`` queue
    the commands for the two kinds of operation, and the underscore-prefixed
    helpers each handle one part of a track (version counter, hash fields,
    counters, lists).
    """

    @staticmethod
    def _check_protected_fields_not_changed_since_read(pipe, params):
        """Verify that the stored track still matches ``params.old_track``.

        The track hash and the version counter are read through
        ``pipe.parent`` and compared, as text, with the values held by
        ``params.old_track`` for every concurrent-protected field and for
        ``track_version``.

        :param pipe: pipeline whose ``parent`` is used for the reads.
        :param params: ``SerializationParams`` of the current operation.
        :raises ConcurrentTrackOperationError: if any compared value differs.
        """
        # A missing version key is treated as version 0.
        actual_field_values = dict(
            decode_redis_data(pipe.parent.hgetall(params.track_key)),
            track_version=pipe.parent.get(params.version_key) or 0,
        )
        old_track = params.old_track
        for field in list(old_track.concurrent_protected_fields) + ['track_version']:
            if field == 'track_version':
                # The version lives on the model itself, not in its data dict.
                track_value = smart_text(old_track.track_version)
            else:
                track_value = smart_text(old_track._data.get(field))
            actual_value = smart_text(actual_field_values.get(field))
            if actual_value != track_value:
                log.warning(
                    'Track %s has changed: %s %s != %s',
                    old_track.track_id,
                    field,
                    actual_value,
                    track_value,
                )
                raise ConcurrentTrackOperationError()

    @staticmethod
    def _create_or_update_track_version(pipe, params):
        """Queue the commands that set or bump the track version counter.

        On create the counter is set to the new track's version and given the
        track's TTL. On update the counter is incremented only when the data
        diff touches at least one concurrent-protected field; the in-memory
        ``track_version`` of the new track is incremented together with it.

        :param pipe: pipeline the commands are queued on.
        :param params: ``SerializationParams`` of the current operation.
        """
        new_track = params.new_track
        if params.is_create:
            pipe.set(params.version_key, new_track.track_version)
            pipe.expire(params.version_key, new_track.ttl)
            return

        data_diff = params.diff.data_diff
        changed_fields = set()
        if data_diff:
            # Iterating the diff dicts yields their keys, i.e. the field names.
            changed_fields.update(data_diff.added, data_diff.changed, data_diff.deleted)

        if changed_fields.intersection(new_track.concurrent_protected_fields):
            pipe.incr(params.version_key)
            # Keep the track model consistent with Redis after the transaction.
            new_track.track_version += 1
            if not (params.old_track and params.old_track.track_version):
                # TODO: remove this branch 3 hours after PASSP-19966 is rolled out to production
                pipe.expire(params.version_key, new_track.ttl)

    @staticmethod
    def _create_or_update_track_data(pipe, params):
        """Queue the commands that write the plain data fields of the track.

        Added and changed fields are written with a single ``hmset``, deleted
        fields are removed with a single ``hdel`` (names in sorted order), and
        on create the track key is given the track's TTL.

        :param pipe: pipeline the commands are queued on.
        :param params: ``SerializationParams`` of the current operation.
        """
        data_diff = params.diff.data_diff
        if data_diff.added or data_diff.changed:
            # Changed values take precedence over added ones on a name clash.
            fields_to_write = dict(data_diff.added)
            fields_to_write.update(data_diff.changed)
            pipe.hmset(params.track_key, fields_to_write)
        if data_diff.deleted:
            pipe.hdel(
                params.track_key,
                *sorted(data_diff.deleted.keys())
            )
        if params.is_create:
            pipe.expire(params.track_key, params.new_track.ttl)

    @staticmethod
    def _create_or_update_track_counters(pipe, params):
        """Queue the commands that update the counters stored in the track hash.

        Changed counters are processed in sorted order: a non-negative delta
        is applied with ``hincrby``, a negative delta with ``hset`` of the
        final value. Deleted counters are removed with a single ``hdel``.

        :param pipe: pipeline the commands are queued on.
        :param params: ``SerializationParams`` of the current operation.
        """
        counter_diff = params.diff.counter_diff
        for counter_name, delta in sorted(iteritems(counter_diff.changed)):
            if delta < 0:
                # The counter was reset, so write its final value directly.
                pipe.hset(params.track_key, counter_name, params.new_track._counters[counter_name])
            else:
                pipe.hincrby(params.track_key, counter_name, delta)

        if counter_diff.deleted:
            pipe.hdel(
                params.track_key,
                *sorted(counter_diff.deleted.keys())
            )

    @staticmethod
    def _create_or_update_track_lists(pipe, params):
        """Queue the commands that update the lists attached to the track.

        Deleted lists are trimmed down to their first element (``LIST_ROOT``).
        Added items are appended with ``rpush``; on create ``LIST_ROOT`` is
        pushed in front of them and the list key is given the track's TTL.
        When the in-memory list is longer than ``MAX_LIST_LENGTH`` both the
        Redis list and the in-memory list are cut to the last
        ``MAX_LIST_LENGTH`` items.

        :param pipe: pipeline the commands are queued on.
        :param params: ``SerializationParams`` of the current operation.
        """
        list_diff = params.diff.list_diff
        new_track = params.new_track

        for list_name, _ in sorted(iteritems(list_diff.deleted)):
            list_key = make_redis_subkey(params.old_track.track_id, list_name)
            # Keep the first element, LIST_ROOT.
            pipe.ltrim(list_key, 0, 0)

        for list_name, items in sorted(iteritems(list_diff.added)):
            list_key = make_redis_subkey(new_track.track_id, list_name)
            if params.is_create:
                items = [LIST_ROOT] + items

            if items:
                pipe.rpush(list_key, *items)
                length = len(new_track._lists[list_name])
                if length > MAX_LIST_LENGTH:
                    # The Redis list holds LIST_ROOT at index 0 followed by the
                    # items (assuming it mirrors the in-memory list), so the
                    # last MAX_LIST_LENGTH items start at index
                    # length - MAX_LIST_LENGTH + 1. Starting one index earlier
                    # would leave one item more in Redis than in the model.
                    pipe.ltrim(list_key, length - MAX_LIST_LENGTH + 1, -1)
                    # The trim dropped LIST_ROOT as well; put it back in front.
                    pipe.lpush(list_key, LIST_ROOT)
                    new_track._lists[list_name] = new_track._lists[list_name][-MAX_LIST_LENGTH:]

            if params.is_create:
                pipe.expire(list_key, new_track.ttl)

    @staticmethod
    def create_or_change(old_track, new_track, pipe, allow_incremental_updates=True):
        """Queue on ``pipe`` the commands that bring Redis to ``new_track``.

        :param old_track: track as it was read, or ``None`` to create a track.
        :param new_track: track state to be stored.
        :param pipe: pipeline to work on; it is left in ``multi`` mode with
            the commands queued, the caller is expected to execute it.
        :param allow_incremental_updates: when ``False``, counter and list
            changes are rejected and the concurrent-change check is skipped.
        :raises RuntimeError: if incremental updates are forbidden but the
            diff contains counter or list changes.
        :raises ConcurrentTrackOperationError: if the stored track no longer
            matches ``old_track``.
        """
        params = SerializationParams(
            old_track=old_track,
            new_track=new_track,
            diff=differ(old_track, new_track),
            is_create=old_track is None,
            track_key=make_redis_key(new_track.track_id),
            version_key=make_redis_subkey(new_track.track_id, 'version'),
        )

        if not allow_incremental_updates and (params.diff.counter_diff or params.diff.list_diff):
            raise RuntimeError('Incremental track updates are forbidden (are you in nested transaction?)')

        # allow_incremental_updates=False means we are inside a nested
        # transaction. Watching old_track for changes is pointless then: they
        # could have been made by the outer transaction.
        if not params.is_create and allow_incremental_updates:
            pipe.watch(params.version_key)
            try:
                Serializer._check_protected_fields_not_changed_since_read(pipe, params)
            except:
                # Deliberately catches everything: drop the watch, then re-raise.
                pipe.unwatch()
                raise

        pipe.multi()
        try:
            Serializer._create_or_update_track_version(pipe, params)
            Serializer._create_or_update_track_data(pipe, params)
            Serializer._create_or_update_track_counters(pipe, params)
            Serializer._create_or_update_track_lists(pipe, params)
        except:
            # Deliberately catches everything: discard queued commands, re-raise.
            pipe.discard()
            raise

    @staticmethod
    def delete(old_track, pipe):
        """Queue on ``pipe`` the commands that delete ``old_track``.

        Deletes the track key, its version key and the key of every list
        named in ``old_track.list_names`` (in sorted order).

        :param old_track: track to delete.
        :param pipe: pipeline to work on; it is left in ``multi`` mode with
            the commands queued, the caller is expected to execute it.
        """
        pipe.multi()
        try:
            track_id = old_track.track_id
            pipe.delete(make_redis_key(track_id))
            pipe.delete(make_redis_subkey(track_id, 'version'))

            for list_name in sorted(old_track.list_names):
                pipe.delete(make_redis_subkey(track_id, list_name))
        except:
            # Deliberately catches everything: discard queued commands, re-raise.
            pipe.discard()
            raise

    @staticmethod
    def execute(old_track, new_track, redis_node, allow_incremental_updates=True):
        """Apply the change from ``old_track`` to ``new_track`` on ``redis_node``.

        :param old_track: track as it was read, or ``None`` to create a track.
        :param new_track: track state to store, or ``None`` to delete
            ``old_track``.
        :param redis_node: object whose ``pipeline(readonly=False)`` provides
            the pipeline to use.
        :param allow_incremental_updates: passed to ``create_or_change``.
        :return: whatever ``pipe.execute()`` returns.
        """
        pipe = redis_node.pipeline(readonly=False)
        if new_track is None:
            Serializer.delete(old_track, pipe)
        else:
            Serializer.create_or_change(old_track, new_track, pipe, allow_incremental_updates)

        return pipe.execute()
