from __future__ import absolute_import

from itertools import chain

from marshmallow import Schema, fields, pre_load, ValidationError, validates_schema
from marshmallow.validate import Range

from common.serialization.fields import DictNestedField


class DoSchema(Schema):
    """Schema describing the ``do`` section of a state.

    Input may be either a full mapping with ``action`` (required), ``args``
    and ``kwargs``, or a shorthand non-dict value, which is taken to be the
    action itself (see :meth:`pre_load`).
    """

    # The action to perform; stored as-is, no type coercion is applied.
    action = fields.Raw(required=True)
    # Optional list of raw values (presumably positional arguments for the
    # action -- confirm against the code that consumes the scheme).
    args = fields.List(fields.Raw())
    # Optional dict (presumably keyword arguments for the action -- confirm).
    kwargs = fields.Dict()

    @pre_load
    def pre_load(self, data):
        """Normalize shorthand input before deserialization.

        :param data: raw input for this schema.
        :return: ``data`` unchanged when it is already a dict, otherwise a
            new dict ``{'action': data}``.
        """
        if isinstance(data, dict):
            return data

        # Anything that is not a mapping is treated as the bare action.
        return {'action': data}


class StateSchema(Schema):
    """Schema for a single state of a process scheme.

    Used as the nested schema for the values of ``ProcessSchemeSchema.states``.
    """
    # Optional description of what the state does, deserialized with DoSchema.
    do = fields.Nested(DoSchema)
    # Transitions out of this state. When the key is absent from the input,
    # ``missing=dict`` supplies a new empty dict. ProcessSchemeSchema's
    # validator reads the *values* of this dict as target state names; the
    # meaning of the keys is not visible here (presumably the outcome that
    # triggers the transition -- confirm against the consumer of the scheme).
    transitions = fields.Dict(missing=dict)


class ProcessSchemeSchema(Schema):
    """Schema for a complete process scheme.

    Besides per-field validation, the schema-level validator checks that the
    set of declared states and the set of transition targets agree with each
    other (see :meth:`validates_schema`).
    """

    # Required mapping whose keys are used as state names by the validator
    # below; values are handled by DictNestedField(StateSchema).
    states = DictNestedField(StateSchema, required=True)
    # Name of the starting state; defaults to 'initial' when omitted.
    initial_state = fields.String(missing='initial')
    # Lock timing settings; each must be at least 0.01.
    # NOTE(review): units are not visible here (presumably seconds -- confirm).
    lock_update_interval = fields.Float(validate=Range(min=0.01), missing=5)
    lock_alive_time = fields.Float(validate=Range(min=0.01), missing=20)

    @validates_schema(skip_on_field_errors=True)
    def validates_schema(self, schema):
        """Check that declared states and transition targets are consistent.

        Three problems are reported:

        * a declared state that is neither the initial state nor the target
          of any transition;
        * a transition target (or the initial state) that is not declared in
          ``states``;
        * a transition target that is unhashable and therefore can never be
          a key of ``states``.

        NOTE(review): "unreachable" here means "has no incoming transition".
        This is not a graph reachability check from the initial state, so a
        state that only transitions to itself passes.

        :param schema: the deserialized scheme data.
        :return: ``schema`` unchanged.
        :raises ValidationError: with the list of all detected problems.
        """
        defined_states = set(schema['states'].keys())

        # The initial state is entered without any transition leading to it.
        can_transit_to = {schema['initial_state']}

        # ``transitions`` is a free-form dict, so its values may be unhashable
        # (for example lists or dicts). Collect those separately instead of
        # letting ``set`` raise a bare TypeError out of the validator.
        unhashable_targets = []
        all_targets = chain.from_iterable(
            s.get('transitions', {}).values() for s in schema['states'].values()
        )
        for target in all_targets:
            try:
                can_transit_to.add(target)
            except TypeError:
                unhashable_targets.append(target)

        errors = []

        unreachable_states = defined_states - can_transit_to
        if unreachable_states:
            errors.append(u'Some states are unreachable through transitions: {}'.format(unreachable_states))

        transit_to_undefined_states = can_transit_to - defined_states
        if transit_to_undefined_states:
            errors.append(u'Some transitions lead to undefined states: {}'.format(transit_to_undefined_states))

        if unhashable_targets:
            errors.append(
                u'Some transitions lead to unhashable values that cannot be state names: {}'.format(
                    unhashable_targets
                )
            )

        if errors:
            raise ValidationError(errors)

        return schema


def load_scheme(schema):
    """Deserialize and validate a raw process scheme.

    The scheme is loaded through ``ProcessSchemeSchema`` created with
    ``strict=True``, so invalid input is expected to raise a marshmallow
    ``ValidationError`` rather than being returned alongside errors.

    :param schema: raw scheme data to load.
    :return: the ``data`` attribute of the load result.
    """
    loader = ProcessSchemeSchema(strict=True)
    result = loader.load(schema)
    return result.data
