import calendar
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from enum import Enum
from typing import Optional

from dateutil.relativedelta import relativedelta
from pytz.tzinfo import BaseTzInfo

from sendr_utils import without_none


def timezone_today(timezone_: BaseTzInfo) -> date:
    """
    Returns the current calendar date as observed in `timezone_`.
    """
    current_moment = datetime.now(tz=timezone_)
    return current_moment.date()


def safe_localize(datetime_: datetime, timezone_: BaseTzInfo) -> datetime:
    """
    Attaches `timezone_` to a naive datetime using `timezone_.localize`.

    Aware datetimes are returned unchanged (they are NOT converted to `timezone_`).
    """
    is_naive = datetime_.tzinfo is None
    return timezone_.localize(datetime_) if is_naive else datetime_


def is_today(value: date, timezone_: BaseTzInfo) -> bool:
    """
    Returns True if `value` is the current date in `timezone_`.
    """
    today = timezone_today(timezone_)
    return value == today


def is_tomorrow(value: date, timezone_: BaseTzInfo) -> bool:
    """
    Returns True if `value` is the day after the current date in `timezone_`.
    """
    # `value` is tomorrow exactly when the day before it is today.
    day_before = value - timedelta(days=1)
    return is_today(day_before, timezone_)


def time_between(time_: time, period_start: time, period_end: time) -> bool:
    """
    Returns True if `time_` is inside the half-open period `[period_start; period_end)`.

    If `period_end` is not after `period_start` the period wraps around midnight
    (e.g. 22:00-06:00). When `period_start == period_end` the period covers the whole day.
    """
    if period_start >= period_end:
        # Wrapping period: `time_` is either in the evening part or in the morning part.
        return time_ >= period_start or time_ < period_end
    return period_start <= time_ < period_end


def time_distance(t1: time, t2: time) -> timedelta:
    """
    Returns the distance between two times of day measured around a 24-hour clock face,
    i.e. the shorter of the forward and backward gaps (for naive times: at most 12 hours).
    """
    # Both times are pinned to the same (arbitrary) date so they can be subtracted.
    anchor_day = datetime.min.date()
    straight_gap = abs(datetime.combine(anchor_day, t1) - datetime.combine(anchor_day, t2))
    return min(straight_gap, timedelta(days=1) - straight_gap)


class Period(Enum):
    """
    Half of the day in the 12-hour clock notation.
    """

    AM = 'am'  # before noon
    PM = 'pm'  # after noon


@dataclass(frozen=True)
class UserTime:
    """
    User provided time which may be absolute (a clock time) or relative (an offset from "now").

    An absolute time may omit `period`; an `hour` in 1..12 then stays ambiguous until a period is
    assumed via `get_assumed_time` or `get_fit_time`. Missing components are treated as zero.
    """

    hour: Optional[int] = None
    minute: Optional[int] = None
    second: Optional[int] = None
    period: Optional[Period] = None
    relative: bool = False

    @staticmethod
    def is_12_hour(hour: int) -> bool:
        """
        Returns True when `hour` can be read as a 12-hour clock value, i.e. lies in 1..12.
        """
        return 1 <= hour <= 12

    @classmethod
    def convert_hour_12_to_24(cls, hour: int, period: Period) -> int:
        """
        Converts a 12-hour clock `hour` (1..12) to the 24-hour clock.

        12 AM becomes 0 and 12 PM stays 12; any `period` other than `Period.AM` is treated as PM.
        """
        assert cls.is_12_hour(hour)
        shift = 0 if period is Period.AM else 12
        if hour == 12:
            # Midnight (12 AM) -> 0, noon (12 PM) -> 12.
            return shift
        return hour + shift

    @property
    def fixed(self) -> bool:
        """
        True when the time needs no period assumption: it is relative, has an explicit period,
        or its hour cannot be a 12-hour clock value (e.g. 0 or 13..23).
        """
        if self.relative or self.period is not None:
            return True
        return self.hour is not None and not self.is_12_hour(self.hour)

    def get_absolute(self, now: datetime) -> datetime:
        """
        Returns `now` shifted by the stored hours/minutes/seconds (relative times only).
        """
        assert self.relative, 'Time must be relative.'
        offset = timedelta(**without_none({
            'hours': self.hour,
            'minutes': self.minute,
            'seconds': self.second,
        }))
        return now + offset

    def get_assumed_time(self, assume_period: Period) -> time:
        """
        Returns the absolute time, using `assume_period` when no period was provided.

        Missing hour/minute/second default to 0; hours outside 1..12 are kept as 24-hour values.
        """
        assert not self.relative, 'Time must be absolute.'
        hour_24 = self.hour or 0
        if self.is_12_hour(hour_24):
            hour_24 = self.convert_hour_12_to_24(hour_24, self.period or assume_period)
        return time(hour=hour_24, minute=self.minute or 0, second=self.second or 0)

    def get_fit_time(self, fit_period_start: time) -> time:
        """
        Picks AM or PM so that the resulting time falls into the 12-hour window starting at `fit_period_start`.

        When the period is already known (or the hour is not a 12-hour value) both candidates are the same
        time, which is then returned. Intended to be called with the usual start of the working morning.
        """
        assert not self.relative, 'Time must be absolute.'
        am_candidate = self.get_assumed_time(Period.AM)
        pm_candidate = self.get_assumed_time(Period.PM)

        # `time` doesn't support timedelta arithmetic, so the 12-hour shift is done on the hour by hand.
        window_end = fit_period_start.replace(hour=(fit_period_start.hour + 12) % 24)

        return am_candidate if time_between(am_candidate, fit_period_start, window_end) else pm_candidate


@dataclass(frozen=True)
class UserDate:
    """
    Allows to store and use user provided date which could be either absolute (possibly partial) or relative.

    For relative dates `day`, `month`, `year` and `weeks` are offsets from today. For absolute dates
    `day`, `month` and `year` are calendar values (any of them may be missing) and `weekday`
    selects a day of the week.
    """

    day: Optional[int] = None
    month: Optional[int] = None
    year: Optional[int] = None
    # Only used by relative dates (see `get_absolute`).
    weeks: Optional[int] = None
    # Only used by absolute dates: 1 is Monday, ..., 7 is Sunday (see `get_assumed_date`).
    weekday: Optional[int] = None
    relative: bool = False

    # How many consecutive months/years `get_assumed_date` tries while looking for one that contains the
    # requested day: 29 February may need up to 8 years (e.g. 1896 -> 1904), any other day at most 2 months.
    _MAX_ROLL_STEPS = 12

    def get_absolute(self, today: date) -> date:
        """
        Returns `today` shifted by the stored offsets: `day` days, `month` months, `year` years, `weeks` weeks.

        Month/year arithmetic is delegated to `relativedelta`; `without_none` presumably drops unset offsets.
        """
        assert self.relative, 'Date must be relative.'
        return today + relativedelta(**without_none({
            'days': self.day,
            'months': self.month,
            'years': self.year,
            'weeks': self.weeks,
        }))

    def get_assumed_date(self, today: date) -> date:
        """
        Resolves an absolute (possibly partial) date against `today`.

        If `weekday` is set, returns the closest date on or after `today` with that weekday.

        Otherwise missing `day`, `month` and `year` are taken from `today`. A missing `day` is clamped to
        the length of the month, e.g. "February" on 31 January gives the last day of February. If the
        date is before `today` or does not exist, it is moved forward:
          * `month` not given -> to the nearest following month that has the requested day;
          * `month` given, `year` not given -> to the nearest following year that has it
            (e.g. the next leap year for 29 February);
          * `month` and `year` given -> never moved, even if it lies in the past.

        Raises ValueError if no such date exists (e.g. 30 February, month 13).
        """
        assert not self.relative, 'Date must be absolute.'
        if self.weekday:
            # `self.weekday` counts from 1 (Monday), `date.weekday()` counts from 0 (Monday).
            days_diff = (self.weekday - 1 - today.weekday()) % 7
            return today + timedelta(days=days_diff)

        year = today.year if self.year is None else self.year
        month = today.month if self.month is None else self.month
        if self.month is not None and self.year is not None:
            # Fully specified date: used as is, even if it lies in the past.
            return self._date_in_month(year, month, today)

        result = self._first_existing(year, month, today)
        if result < today:
            # Move to the next occurrence: next month if the month is free, otherwise next year.
            next_year, next_month = self._next_period(result.year, result.month)
            result = self._first_existing(next_year, next_month, today)
        return result

    def get_date(self, today: date) -> date:
        """
        Resolves the date against `today`: relative dates via `get_absolute`, absolute ones via `get_assumed_date`.
        """
        if self.relative:
            return self.get_absolute(today)
        else:
            return self.get_assumed_date(today)

    def _date_in_month(self, year: int, month: int, today: date) -> date:
        """
        Builds the requested day of the given month.

        A missing `day` defaults to today's day clamped to the month's length, so that an assumed day
        never makes the date invalid. Raises ValueError if the day, month or year is invalid.
        """
        if self.day is not None:
            return date(year, month, self.day)
        # calendar.monthrange raises a ValueError subclass for a month outside 1..12.
        days_in_month = calendar.monthrange(year, month)[1]
        return date(year, month, min(today.day, days_in_month))

    def _next_period(self, year: int, month: int) -> tuple:
        """
        Returns `(year, month)` advanced by one month when `month` is not set, otherwise by one year.
        """
        if self.month is None:
            return (year + 1, 1) if month == 12 else (year, month + 1)
        return year + 1, month

    def _first_existing(self, year: int, month: int, today: date) -> date:
        """
        Returns the requested day in `(year, month)` or, if it does not exist there, in the nearest
        following month/year (see `_next_period`) that contains it.

        Raises ValueError if none of the `_MAX_ROLL_STEPS + 1` tried months/years contains it.
        """
        for _ in range(self._MAX_ROLL_STEPS):
            try:
                return self._date_in_month(year, month, today)
            except ValueError:
                year, month = self._next_period(year, month)
        # Last attempt: lets the ValueError of an impossible date (e.g. 30 February) propagate.
        return self._date_in_month(year, month, today)
