from datetime import date
from collections import defaultdict
from itertools import product, groupby, chain
from operator import attrgetter, itemgetter

from django.conf import settings
from django.db.models import Sum

from staff.person.models import Staff
from staff.departments.models import Department
from staff.map.models import City, Country

from staff.departments.utils import get_department_path
from staff.lib.utils.date import get_date_setting

from .forms import COUNTRIES_BY, GROUP_BY
from .models import LentaLog, ACTION_CHOICES, CountAndGrowth


class Archaeologist(object):
    """Reconstructs a person's employment history from their LentaLog entries.

    Used to back-fill missing "hired"/"dismissed" log entries and to list the
    (start, end, department, position) spans of a person's employment.
    """
    COMPANY_FOUNDED = date(1990, 12, 1)  # iseg.join_at
    DEP_ACTIONS = (ACTION_CHOICES.HIRED,
                   ACTION_CHOICES.TRANSFERRED,
                   ACTION_CHOICES.TRANSFERRED_AND_CHANGED_POSITION)

    def __init__(self, person):
        self.person = person
        # The person's log entries in chronological order (ties broken by id).
        self.qs = self.person.lentalog_set.order_by('created_at', 'id')

    def guess_join_date(self):
        """Return the most plausible moment the person joined.

        ``person.join_at`` is trusted only when it is not earlier than
        COMPANY_FOUNDED and is earlier than the day the person record was
        created; otherwise ``person.created_at`` is returned unchanged.

        NOTE(review): the two branches return different types (a date vs. the
        ``created_at`` value, presumably a datetime) — callers must cope with
        both.
        """
        joined = self.person.join_at
        created = self.person.created_at
        if self.COMPANY_FOUNDED <= joined < created.date():
            return joined
        return created

    def guess_quit_date(self):
        """Return the person's recorded quit date (``person.quit_at``)."""
        return self.person.quit_at

    def _guess_first(self, field):
        """Return the earliest known value of *field*.

        This is the first truthy ``<field>_old`` found scanning the logs
        chronologically, falling back to the person's current value.
        """
        for log in self.qs:
            value = getattr(log, field + '_old', None)
            if value:
                return value
        return getattr(self.person, field)

    def _guess_last(self, field):
        """Return the latest known value of *field*.

        This is the first truthy ``<field>_new`` found scanning the logs from
        newest to oldest, falling back to the person's current value.
        """
        for log in reversed(self.qs):
            value = getattr(log, field + '_new', None)
            if value:
                return value
        return getattr(self.person, field)

    def guess_hirement_log(self):
        """Build (without saving) a HIRED log entry for the person.

        The "new" side carries the earliest known position, department,
        office and organization; the "old" side is empty.
        """
        first_department = self.guess_first_department()
        return LentaLog(
            action=ACTION_CHOICES.HIRED,
            staff=self.person,
            created_at=self.guess_join_date(),
            position_old='',
            position_new=self.guess_first_position(),
            department_old=None,
            department_new=first_department,
            # The path must describe department_new (the first department),
            # not the person's last known department.
            department_new_path=get_department_path(first_department),
            office_old=None,
            office_new=self.guess_first_office(),
            organization_old=None,
            organization_new=self.guess_first_organization()
        )

    def guess_dismissal_log(self):
        """Build (without saving) a DISMISSED log entry for the person.

        The "old" side holds the person's current values, the "new" side the
        latest values known from the logs; ``created_at`` is ``quit_at``.
        """
        last_department = self.guess_last_department()
        return LentaLog(
            action=ACTION_CHOICES.DISMISSED,
            staff=self.person,
            created_at=self.guess_quit_date(),
            position_old=self.person.position,
            position_new=self.guess_last_position(),
            department_old=self.person.department,
            department_new=last_department,
            department_new_path=get_department_path(last_department),
            office_old=self.person.office,
            office_new=self.guess_last_office(),
            organization_old=self.person.organization,
            organization_new=self.guess_last_organization()
        )

    @staticmethod
    def _save_with_created_at(log):
        """Save *log* so that its pre-assigned ``created_at`` is kept.

        The log is saved once, ``created_at`` is re-assigned and it is saved
        again. NOTE(review): presumably ``created_at`` is overwritten on
        insert (e.g. ``auto_now_add``) — confirm against the model.
        """
        created_at = log.created_at
        log.save()
        log.created_at = created_at
        log.save()
        return log

    def create_hirement_log(self):
        """Build, save and return the guessed HIRED log entry."""
        return self._save_with_created_at(self.guess_hirement_log())

    def create_dismissal_log(self):
        """Build, save and return the guessed DISMISSED log entry."""
        return self._save_with_created_at(self.guess_dismissal_log())

    def guess_first_department(self):
        return self._guess_first('department')

    def guess_last_department(self):
        return self._guess_last('department')

    def guess_first_office(self):
        return self._guess_first('office')

    def guess_last_office(self):
        return self._guess_last('office')

    def guess_first_organization(self):
        return self._guess_first('organization')

    def guess_last_organization(self):
        return self._guess_last('organization')

    def guess_first_position(self):
        return self._guess_first('position')

    def guess_last_position(self):
        return self._guess_last('position')

    def is_hirement_missing(self):
        """True when there are no logs or the first log is not a HIRED one."""
        return not self.qs or self.qs[0].action != ACTION_CHOICES.HIRED

    def is_dismissal_missing(self):
        """Truthy when the person has quit but no DISMISSED log exists.

        Returns ``quit_at`` itself when it is falsy, otherwise a bool.
        """
        return (
            self.person.quit_at and
            ACTION_CHOICES.DISMISSED not in map(attrgetter('action'), self.qs)
        )

    @classmethod
    def fill_missing_hirements(cls):
        """Create a HIRED log for every Staff lacking one; yield each log."""
        for person in Staff.objects.all():
            arch = cls(person)
            if arch.is_hirement_missing():
                yield arch.create_hirement_log()

    @classmethod
    def fill_missing_dismissals(cls):
        """Create a DISMISSED log for every quit Staff lacking one; yield each log."""
        for person in Staff.objects.all():
            arch = cls(person)
            if arch.is_dismissal_missing():
                yield arch.create_dismissal_log()

    def employments(self):
        """Yield ``(start, end, department, position)`` spans of employment.

        Each log opens a span that lasts until the next log; the last span
        has ``end=None``. When the history does not start with a HIRED log, a
        leading span from the guessed join date is synthesised from the first
        log's "old" values. Without any logs a single open span is yielded.
        """
        logs = self.qs.iterator()

        try:
            previous = next(logs)
        except StopIteration:
            yield (self.guess_join_date(),
                   None,
                   self.guess_first_department(),
                   self.guess_first_position())
            return

        if previous.action != ACTION_CHOICES.HIRED:
            yield (self.guess_join_date(),
                   previous.created_at,
                   previous.department_old,
                   previous.position_old)

        for log in logs:
            yield (previous.created_at,
                   log.created_at,
                   previous.department_new,
                   previous.position_new)
            previous = log

        yield (previous.created_at,
               None,
               previous.department_new,
               previous.position_new)


class DepartmentCache:
    """Lazily built map: department id -> ids of its senior departments.

    Each value is a tuple ``(root_id, level_1_id, level_2_id)`` padded with
    None, where the levels are the ancestors (or the department itself)
    whose kind is DIS_DIRECTION_KIND_ID or DIS_DIVISION_KIND_ID.
    """

    def __init__(self):
        self._is_populated = False
        self._cache = None
        self.qs = (
            Department.objects
            .order_by('tree_id', 'lft')
            .values_list('id', 'parent_id', 'kind_id')
        )

    def populate(self):
        """Build the cache on first use; later calls do nothing."""
        if self._is_populated:
            return
        self._cache = dict(self._iter_dep_levels())
        self._is_populated = True

    def __getitem__(self, item):
        """Return the level tuple for department *item* (KeyError if unknown)."""
        self.populate()
        return self._cache[item]

    def _iter_dep_levels(self):
        """
        Yield ``(dep_id, level_ids)`` for every department.

        ``level_ids`` holds the ids of up to three senior departments in the
        department's chain: the tree root, then every direction/division on
        the way down. Departments of other kinds inherit the chain of their
        nearest such senior. Short chains are padded with None.

        NOTE(review): ordering by (tree_id, lft) is relied on to visit each
        parent before its children (a KeyError is raised otherwise). A chain
        with more than two direction/division levels yields a tuple longer
        than three.
        """
        chains = {}
        for dep_id, parent_id, kind in self.qs:
            if parent_id is None:
                chain_ = [dep_id]
            elif kind in (settings.DIS_DIRECTION_KIND_ID, settings.DIS_DIVISION_KIND_ID):
                chain_ = chains[parent_id] + [dep_id]
            else:
                chain_ = chains[parent_id]

            chains[dep_id] = chain_

            padding = [None] * (3 - len(chain_))
            yield dep_id, tuple(chain_ + padding)


class Counter:
    """Additive record of CountAndGrowth values for one key.

    The key consists of ``created_at``, ``headcount`` and one ``<fk>_id`` per
    CountAndGrowth.FK_FIELDS entry. Value fields default to 0 and key fields
    to None. A counter whose key fields are all None is "neutral": it can be
    added to any counter, on either side, which is what lets
    ``defaultdict(Counter)`` accumulate values.
    """
    key_fields = (
        ('created_at', 'headcount') + tuple(f + '_id' for f in CountAndGrowth.FK_FIELDS)
    )
    fields = key_fields + CountAndGrowth.VALUE_FIELDS

    def __init__(self, given_attributes=None):
        attributes = given_attributes if given_attributes is not None else {}
        for attr in CountAndGrowth.VALUE_FIELDS:
            setattr(self, attr, attributes.get(attr, 0))
        for attr in self.key_fields:
            setattr(self, attr, attributes.get(attr))

    def __add__(self, other):
        """Return a new Counter with summed values.

        @type other: Counter

        The result takes its key from whichever operand is not neutral. This
        is symmetric: ``neutral + c`` and ``c + neutral`` both equal ``c``.
        Raises TypeError when both operands carry different non-neutral keys.
        """
        if self.key == other.key or self.is_neutral:
            key_source = other
        elif other.is_neutral:
            key_source = self
        else:
            raise TypeError('Counters have incompatible keys')

        attributes = {
            attr: getattr(self, attr) + getattr(other, attr)
            for attr in CountAndGrowth.VALUE_FIELDS
        }
        attributes.update(key_source.as_dict(key_source.key_fields))
        return Counter(attributes)

    def __str__(self):
        return str(self.as_dict())

    @property
    def key(self):
        """Tuple of the key field values, in ``key_fields`` order."""
        return tuple(getattr(self, attr) for attr in self.key_fields)

    @property
    def is_neutral(self):
        """True when every key field is None."""
        return self.key == (None,) * len(self.key_fields)

    def save(self):
        """Persist this counter as a new CountAndGrowth row and return it."""
        return CountAndGrowth.objects.create(**self.as_dict())

    def as_dict(self, fields=None):
        """Return ``{field: value}`` for *fields* (default: all ``fields``)."""
        attrs = fields if fields is not None else self.fields
        return {
            attr: getattr(self, attr)
            for attr in attrs
        }


class CounterPart(object):
    """One side ("old" or "new") of a LentaLog entry, exposed as Counter fields.

    Key attributes (date, headcount, office, organization, department levels)
    come from the entry's ``*_<suffix>`` fields. Value attributes (total,
    hired, fired, returned, joined, left) are derived from the entry's action
    and the suffix.
    """
    # public:

    inner_migrations = (ACTION_CHOICES.TRANSFERRED,
                        ACTION_CHOICES.TRANSFERRED_AND_CHANGED_POSITION,
                        ACTION_CHOICES.MOVED,
                        ACTION_CHOICES.CHANGED_ORGANIZATION)

    # Headcount delta contributed by each (action, suffix) pair. It is built
    # once per class instead of on every ``total`` access.
    _TOTAL_DELTAS = {
        (ACTION_CHOICES.HIRED, 'old'): 0,
        (ACTION_CHOICES.HIRED, 'new'): +1,

        (ACTION_CHOICES.DISMISSED, 'old'): -1,
        (ACTION_CHOICES.DISMISSED, 'new'): 0,

        (ACTION_CHOICES.CHANGED_POSITION, 'old'): 0,
        (ACTION_CHOICES.CHANGED_POSITION, 'new'): 0,

        (ACTION_CHOICES.CHANGED_HEADCOUNT, 'old'): -1,
        (ACTION_CHOICES.CHANGED_HEADCOUNT, 'new'): +1,

        (ACTION_CHOICES.TRANSFERRED, 'old'): -1,
        (ACTION_CHOICES.TRANSFERRED, 'new'): +1,

        (ACTION_CHOICES.TRANSFERRED_AND_CHANGED_POSITION, 'old'): -1,
        (ACTION_CHOICES.TRANSFERRED_AND_CHANGED_POSITION, 'new'): +1,

        (ACTION_CHOICES.RETURNED, 'old'): 0,
        (ACTION_CHOICES.RETURNED, 'new'): +1,

        (ACTION_CHOICES.MOVED, 'old'): -1,
        (ACTION_CHOICES.MOVED, 'new'): +1,

        (ACTION_CHOICES.CHANGED_ORGANIZATION, 'old'): -1,
        (ACTION_CHOICES.CHANGED_ORGANIZATION, 'new'): +1,
    }

    def __init__(self, lenta_log, suffix, department_cache):
        self.entry = lenta_log
        self.suffix = suffix
        self.department_cache = department_cache

    def __bool__(self):
        """A part counts only if office, organization and top department are known."""
        return None not in (
            self.office_id, self.organization_id, self.department_0_id
        )

    def __str__(self):
        return '<%s, %s>' % (self.entry, self.suffix)

    @property
    def key(self):
        """Tuple of this part's values for Counter.key_fields."""
        return tuple(getattr(self, attr) for attr in Counter.key_fields)

    @property
    def counter(self):
        """A Counter holding this part's key and value fields."""
        return Counter({attr: getattr(self, attr) for attr in Counter.fields})

    # protected:

    @property
    def created_at(self):
        return self.entry.created_at.date()

    @property
    def office_id(self):
        return getattr(self.entry, 'office_%s_id' % self.suffix)

    @property
    def organization_id(self):
        return getattr(self.entry, 'organization_%s_id' % self.suffix)

    @property
    def headcount(self):
        return getattr(self.entry, 'headcount_%s' % self.suffix)

    @property
    def department_0_id(self):
        return self._department(0)

    @property
    def department_1_id(self):
        return self._department(1)

    @property
    def department_2_id(self):
        return self._department(2)

    @property
    def total(self):
        """Headcount delta of this part; KeyError for an unlisted (action, suffix)."""
        return self._TOTAL_DELTAS[self.entry.action, self.suffix]

    @property
    def hired(self):
        """1 for the "new" side of a HIRED entry, else 0."""
        return 1 if self.entry.action == ACTION_CHOICES.HIRED and self._is_new else 0

    @property
    def fired(self):
        """1 for the "old" side of a DISMISSED entry, else 0."""
        return 1 if self.entry.action == ACTION_CHOICES.DISMISSED and self._is_old else 0

    @property
    def returned(self):
        """1 for the "new" side of a RETURNED entry, else 0."""
        return 1 if self.entry.action == ACTION_CHOICES.RETURNED and self._is_new else 0

    @property
    def joined(self):
        """1 for the "new" side of an inner migration, else 0."""
        return 1 if self.entry.action in self.inner_migrations and self._is_new else 0

    @property
    def left(self):
        """1 for the "old" side of an inner migration, else 0."""
        return 1 if self.entry.action in self.inner_migrations and self._is_old else 0

    # private (helpers):

    def _department(self, level):
        """Id of the senior department at *level* (0-2) for this side, or None."""
        dep_id = getattr(self.entry, 'department_%s_id' % self.suffix)
        if dep_id is None:
            return None
        return self.department_cache[dep_id][level]

    @property
    def _is_new(self):
        return self.suffix == 'new'

    @property
    def _is_old(self):
        return self.suffix == 'old'


class Aggregator:
    """Rolls LentaLog entries of a date range up into CountAndGrowth rows.

    Every entry is looked at from its "old" and its "new" side. Sides with a
    known office, organization and top-level department are summed into one
    Counter per key, and every resulting Counter is saved.
    """
    suffixes = 'old', 'new'

    def __init__(self, start, end):
        self.start, self.end = start, end
        self.qs = LentaLog.objects.filter(created_at__range=(start, end))
        self.department_cache = DepartmentCache()
        self.counters = defaultdict(Counter)

    def run(self):
        """Accumulate and save the counters; return how many were saved."""
        counters = self.counters
        for entry in self.qs:
            for suffix in self.suffixes:
                part = CounterPart(entry, suffix, self.department_cache)
                if not part:
                    continue
                counters[part.key] += part.counter

        for counter in counters.values():
            counter.save()

        return len(counters)


class FlexChart(object):
    """Headcount/growth chart over aggregated CountAndGrowth rows.

    Iterating an instance yields ``(period, geo_id, values)`` triples. The
    triples are ordered by geo, then period. *geo_id* is a city or country id,
    depending on ``group_by``. *period* is the number of ``step``-day buckets
    between ``start`` and the row date, rounded up. Period 0 carries the total
    accumulated before ``start``.

    NOTE: iterating updates ``totals`` (into running totals), ``used_geos``
    and ``grouped_delta`` in place, so an instance should be iterated once.
    """
    EXPORT_FIELDS = ('period',) + CountAndGrowth.ALL_FIELDS
    START = get_date_setting('CHART_START_DATE')
    HOMIE_OFFICE_ID = settings.HOMIE_OFFICE_ID
    LOOKUP_SEP = '__'
    STEP = 1

    value_fields = CountAndGrowth.VALUE_FIELDS

    # public:

    def __init__(self, countries_by,
                 start=None, end=None, step=None,
                 department_0=None, department_1=None,
                 cities=(), with_homies=False,
                 only_headcount=False, group_by=None):
        """
        :param countries_by: lookup prefix used to reach the city
            (``<countries_by>__city``); must be one of COUNTRIES_BY.
        :param start: first day of the chart; defaults to START.
        :param end: last day of the chart (inclusive); defaults to today.
        :param step: bucket size in days; defaults to STEP.
        :param department_0: optional ``department_0`` filter.
        :param department_1: optional ``department_1`` filter.
        :param cities: optional city ids to restrict the chart to.
        :param with_homies: keep rows of office HOMIE_OFFICE_ID.
        :param only_headcount: keep only rows with ``headcount=True``.
        :param group_by: GROUP_BY.city (default) or GROUP_BY.country.
        """
        self.start = self.START if start is None else start
        self.end = date.today() if end is None else end
        self.step = self.STEP if step is None else step
        self.department_0 = department_0
        self.department_1 = department_1
        self.cities = cities
        self.exclude_homies = not with_homies
        self.only_headcount = only_headcount
        assert countries_by in (c[0] for c in COUNTRIES_BY)
        self.countries_by = countries_by
        self.group_by = group_by or GROUP_BY.city
        assert self.group_by in (g[0] for g in GROUP_BY)

        self._range_qs = None
        self._totals_qs = None
        self._raw_qs = None
        self._data = None
        self._all_cities = None
        self._totals = None
        self.used_geos = []
        self.grouped_delta = {}

    def __iter__(self):
        return iter(self._aggregate())

    @property
    def totals(self):
        """``{geo: summed total}`` of rows dated strictly before ``start``.

        The dict is cached and updated in place into running totals while
        iterating the chart.
        """
        if self._totals is None:
            self._totals = {
                counter[self._geo_attr]: counter['total']
                for counter in self._totals_queryset
            }
        return self._totals

    # private:

    @property
    def _range_queryset(self):
        if self._range_qs is None:
            self._range_qs = self._make_range_queryset()
        return self._range_qs

    @property
    def _totals_queryset(self):
        if self._totals_qs is None:
            self._totals_qs = self._make_totals_queryset()
        return self._totals_qs

    def _calc_total(self, counter, city):
        """Add the row's total to the running total of geo *city*; return it."""
        total = self.totals.get(city, 0) + counter['total']
        self.totals[city] = total
        return total

    def _calc_grouped_delta(self, counter):
        """Accumulate the row's total into the per-period delta over all geos.

        Periods whose accumulated delta is zero are kept out of the dict.
        """
        period = counter['period']
        delta = self.grouped_delta.get(period, 0) + counter['total']
        if delta != 0:
            self.grouped_delta[period] = delta
        elif period in self.grouped_delta:
            del self.grouped_delta[period]

    def _calc_value(self, city, counter):
        """Build the value dict for one aggregated row of geo *city*.

        The dict starts with the summed value fields. When 'total' is present,
        'delta' is set to the row total and 'total' to the running total.
        'plus' is set to hired + returned and 'minus' to fired when those
        fields are present.
        """
        value = {name: counter[name] for name in self.value_fields}
        if 'total' in value:
            value['delta'] = counter['total']
            self._calc_grouped_delta(counter)
            value['total'] = self._calc_total(counter, city)

        if all(f in value for f in ['hired', 'returned']):
            value['plus'] = sum(
                value[f] for f in ['hired', 'returned']
            )

        if all(f in value for f in ['fired']):
            value['minus'] = sum(value[f] for f in ['fired'])

        return value

    def _zero_value(self, geo):
        """Values of a synthetic row holding only *geo*'s carried-over total."""
        return {
            f: self.totals[geo] if f == 'total' else 0
            for f in chain(self.value_fields, ['plus', 'minus', 'delta'])
        }

    def _aggregate(self):
        """Yield ``(period, geo, values)`` for the chart.

        For each geo with rows in the range, a synthetic period-0 row is
        emitted first when the geo has a prior total but no row of its own
        in period 0. Geos with a prior total and no rows in the range get
        only a synthetic period-0 row, after all the others.
        """
        ZERO_PERIOD = 0

        qs_iter = self._range_queryset.iterator()

        for geo, counters in groupby(qs_iter, itemgetter(self._geo_attr)):
            self.used_geos.append(geo)

            counter = next(counters)

            if (
                counter['period'] != ZERO_PERIOD
                and 'total' in self.value_fields
                and geo in self.totals
            ):
                yield ZERO_PERIOD, geo, self._zero_value(geo)

            yield counter['period'], geo, self._calc_value(geo, counter)
            for counter in counters:
                yield counter['period'], geo, self._calc_value(geo, counter)

        for geo in set(self.totals.keys()) - set(self.used_geos):
            self.used_geos.append(geo)
            yield ZERO_PERIOD, geo, self._zero_value(geo)

    def _make_range_queryset(self):
        """Rows dated in [start, end], grouped by (geo, period), values summed."""
        lookup = self._lookup
        lookup['created_at__range'] = (self.start, self.end)

        annotations = {counter_name: Sum(counter_name)
                       for counter_name in self.value_fields}

        return (
            CountAndGrowth.objects
            .extra(**self._extra)
            .filter(**lookup)
            .exclude(**self._exclude)
            .values(*self._grouping_attrs)
            .annotate(**annotations)
            .order_by(*self._grouping_attrs)
        )

    def _make_totals_queryset(self):
        """Per-geo sum of 'total' for rows dated strictly before ``start``.

        Rows dated on ``start`` are part of the (inclusive) range queryset
        as period 0, and ``_calc_total`` adds them to the running total.
        Including them here too (``__lte``) would count them twice.
        """
        lookup = self._lookup
        lookup['created_at__lt'] = self.start

        return (
            CountAndGrowth.objects
            .filter(**lookup)
            .exclude(**self._exclude)
            .values(self._geo_attr)
            .annotate(total=Sum('total'))
            .order_by(self._geo_attr)
        )

    @property
    def _extra(self):
        """Raw select for 'period': days since ``start`` divided by ``step``, rounded up."""
        # noinspection PyProtectedMember
        db_table = CountAndGrowth._meta.db_table
        period = f'cast(ceil((date_part(\'day\', {db_table}.created_at::timestamp - %s::timestamp)) / %s) as integer)'
        return {
            'select': {
                'period': period
            },
            'select_params': (self.start, self.step)
        }

    @property
    def _lookup(self):
        """A fresh filter dict built from the chart options (safe to mutate)."""
        lookup = {self._geo_attr + '__isnull': False}
        if self.department_0:
            lookup['department_0'] = self.department_0
        if self.department_1:
            lookup['department_1'] = self.department_1
        if self.cities:
            lookup[self._city_attr + '__in'] = self.cities
        if self.only_headcount:
            lookup['headcount'] = True

        return lookup

    @property
    def _exclude(self):
        if self.exclude_homies:
            return {'office_id': self.HOMIE_OFFICE_ID}
        else:
            return {}

    @property
    def _geo_attr(self):
        return {
            GROUP_BY.city: self._city_attr,
            GROUP_BY.country: self._country_attr,
        }[self.group_by]

    @property
    def _city_attr(self):
        return self.LOOKUP_SEP.join((self.countries_by, 'city'))

    @property
    def _country_attr(self):
        return self.LOOKUP_SEP.join((self.countries_by, 'city', 'country'))

    @property
    def _grouping_attrs(self):
        return self._geo_attr, 'period',

    def get_used_geos(self):
        """Ids of the geos emitted so far, in display (position) order."""
        model, order_fields = {
            GROUP_BY.city: (City,
                            ('country__position', 'country', 'position')),
            GROUP_BY.country: (Country,
                               ('position',)),
        }[self.group_by]
        return model.objects.filter(pk__in=self.used_geos).order_by(
            *order_fields
        ).values_list('id', flat=True)
