# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

from bisect import bisect_right

from django.utils.encoding import force_bytes
from pytz import timezone, utc
from pytz.tzinfo import DstTzInfo, StaticTzInfo


class ShrinkBoundError(Exception):
    """Raised when a datetime falls outside a shrunk timezone's UTC window."""


class ShrinkedTzInfoMixin(object):
    """Restrict a pytz tzinfo to the UTC window ``[_min_dt, _max_dt]``.

    Mixed in front of a pytz ``StaticTzInfo``/``DstTzInfo`` base. Converting a
    UTC datetime into this zone (``fromutc``) or attaching this zone to a naive
    local datetime (``localize``) raises ``ShrinkBoundError`` when the UTC
    instant involved lies outside the inclusive window.
    """

    # Naive UTC datetimes bounding the window; overridden in subclasses.
    _min_dt = _max_dt = None

    def _check_bounds(self, naive_utc_dt):
        """Raise ShrinkBoundError unless *naive_utc_dt* lies inside the window."""
        if not (self._min_dt <= naive_utc_dt <= self._max_dt):
            raise ShrinkBoundError

    def fromutc(self, dt):
        """Convert UTC *dt* to local time, rejecting instants outside the window.

        :raises ShrinkBoundError: if *dt* is outside ``[_min_dt, _max_dt]``.
        """
        local_dt = super(ShrinkedTzInfoMixin, self).fromutc(dt)
        self._check_bounds(dt.replace(tzinfo=None))
        return local_dt

    def localize(self, dt, is_dst=False):
        """Attach this zone to naive local *dt*, rejecting instants outside the window.

        :raises ShrinkBoundError: if the UTC instant of the localized datetime
            is outside ``[_min_dt, _max_dt]``.
        """
        local_dt = super(ShrinkedTzInfoMixin, self).localize(dt, is_dst)
        # ``astimezone(utc)`` dispatches to ``utc.fromutc``, not to this class's
        # ``fromutc``, and the base ``localize`` is not guaranteed to route
        # through ``fromutc`` either, so the bound must be checked explicitly.
        self._check_bounds(local_dt.astimezone(utc).replace(tzinfo=None))
        return local_dt


class ShrinkedStaticTzInfo(ShrinkedTzInfoMixin, StaticTzInfo):
    """Fixed-offset pytz tzinfo restricted to a UTC window.

    Concrete subclasses supply ``zone``, ``_utcoffset``, ``_tzname``,
    ``_min_dt`` and ``_max_dt`` as class attributes.
    """


class ShrinkedDstTzInfo(ShrinkedTzInfoMixin, DstTzInfo):
    """Transition-table pytz tzinfo restricted to a UTC window.

    Concrete subclasses supply ``zone``, ``_utc_transition_times``,
    ``_transition_info``, ``_min_dt`` and ``_max_dt`` as class attributes.
    """


def _make_subclass(name, base, class_dict):
    return type(force_bytes(name), (base,), class_dict)


class ShrinkedTimezones(object):
    """Factory and cache of pytz timezones shrunk to a fixed UTC window.

    Each zone's transition table is cut down to the transitions that can
    affect instants between ``min_dt`` and ``max_dt`` (naive UTC datetimes,
    both inclusive). A zone left with a single applicable transition becomes
    a fixed-offset ``ShrinkedStaticTzInfo``.
    """

    def __init__(self, min_dt, max_dt):
        """Create a factory for the UTC window ``[min_dt, max_dt]``.

        :param min_dt: naive UTC datetime, inclusive lower bound.
        :param max_dt: naive UTC datetime, inclusive upper bound; must be
            strictly later than ``min_dt``.
        :raises ValueError: if either bound is timezone-aware or
            ``min_dt >= max_dt``.
        """
        # Checked before comparing: aware vs naive datetimes cannot be compared.
        if min_dt.tzinfo is not None or max_dt.tzinfo is not None:
            raise ValueError('min_dt and max_dt must be naive UTC datetimes')
        if min_dt >= max_dt:
            raise ValueError('min_dt must be earlier than max_dt')

        self._min_dt = min_dt
        self._max_dt = max_dt
        self._cache = {}

    def get(self, zone):
        """Return the shrunk tzinfo for *zone*, building it on first use."""
        try:
            return self._cache[zone]
        except KeyError:
            tzinfo = self._cache[zone] = self._build_shrinked_tzinfo(zone)
            return tzinfo

    @staticmethod
    def _transition_window(utc_times, min_dt, max_dt):
        """Return inclusive ``(left, right)`` indices of the relevant transitions.

        :param utc_times: sorted naive UTC transition times.
        :returns: ``left`` is the index of the transition in effect at
            ``min_dt`` (0 when ``min_dt`` precedes every transition).
            ``right`` is the index of the last transition at or before
            ``max_dt``, never less than ``left``.
        """
        left = max(0, bisect_right(utc_times, min_dt) - 1)
        # bisect_right: a transition exactly at max_dt is in effect at max_dt,
        # which the (inclusive) bound check accepts, so it must be kept.
        right = max(left, bisect_right(utc_times, max_dt) - 1)
        return left, right

    def _build_shrinked_tzinfo(self, zone):
        """Build (uncached) the shrunk tzinfo for *zone*.

        Zones without a transition table are returned unchanged. These are
        pytz ``StaticTzInfo`` zones and the ``pytz.utc`` singleton returned
        for ``'UTC'``.
        """
        full_tzinfo = timezone(zone)
        if not isinstance(full_tzinfo, DstTzInfo):
            # Only DstTzInfo has _utc_transition_times; pytz.utc is not a
            # StaticTzInfo, so a StaticTzInfo-only test would let it through.
            return full_tzinfo

        min_dt, max_dt = self._min_dt, self._max_dt
        utc_times = full_tzinfo._utc_transition_times
        left_index, right_index = self._transition_window(utc_times, min_dt, max_dt)
        class_attrs = dict(zone=zone, _min_dt=min_dt, _max_dt=max_dt)

        if left_index == right_index:
            # One offset covers the whole window: DstTzInfo => StaticTzInfo.
            utcoffset, _dst, tzname = full_tzinfo._transition_info[left_index]
            class_attrs.update(_utcoffset=utcoffset, _tzname=tzname)
            return _make_subclass(zone, ShrinkedStaticTzInfo, class_attrs)()

        window = slice(left_index, right_index + 1)
        class_attrs.update(_utc_transition_times=utc_times[window],
                           _transition_info=full_tzinfo._transition_info[window])
        return _make_subclass(zone, ShrinkedDstTzInfo, class_attrs)()
