import collections
import copy
import datetime
import decimal
import logging
import threading
import time

from django.db import transaction
from django.db.models import Q
from django.utils import timezone

import cars.settings
from ..models.car_model import CarModel
from ..models.tariff import CarsharingTariff
from ..models.tariff_plan import CarsharingTariffPlan
from ..models.tariff_plan_entry import CarsharingTariffPlanEntry


LOGGER = logging.getLogger(__name__)


class CarsharingTariffPicker:
    """Selects ride and parking tariffs that are effective at a given moment.

    Candidates are ``CarsharingTariff`` rows of the requested types whose day
    of week and active time window either match the moment or are unset.
    """

    class TariffMissingError(Exception):
        """Raised when no stored tariff matches one of the requested types."""

    def pick_ride_and_parking_tariffs(self, dt=None):
        """Return the ``(ride, parking)`` tariff pair for ``dt`` (now if omitted)."""
        moment = timezone.now() if dt is None else dt
        wanted_types = [
            CarsharingTariff.Type.RIDE,
            CarsharingTariff.Type.PARKING,
        ]
        ride_tariff, parking_tariff = self._pick_tariffs(  # pylint: disable=unbalanced-tuple-unpacking
            dt=moment,
            types=wanted_types,
        )
        return ride_tariff, parking_tariff

    def pick_ride_tariff(self, dt=None):
        """Return the ride tariff for ``dt`` (now if omitted)."""
        moment = timezone.now() if dt is None else dt
        return self._pick_tariff(dt=moment, type_=CarsharingTariff.Type.RIDE)

    def pick_parking_tariff(self, dt=None):
        """Return the parking tariff for ``dt`` (now if omitted)."""
        moment = timezone.now() if dt is None else dt
        return self._pick_tariff(dt=moment, type_=CarsharingTariff.Type.PARKING)

    def _pick_tariff(self, dt, type_):
        """Return the single best tariff of ``type_`` for ``dt``."""
        (tariff,) = self._pick_tariffs(dt=dt, types=[type_])[:1]
        return tariff

    def _pick_tariffs(self, dt, types):
        """Return one tariff per entry of ``types``, in the same order.

        Raises:
            TariffMissingError: if some requested type has no matching tariff.
        """
        type_filter = Q(type__in=[t.value for t in types])
        # A null column means "no restriction" for that dimension.
        dow_filter = Q(dow=dt.isoweekday()) | Q(dow__isnull=True)
        from_filter = Q(active_from__lte=dt.time()) | Q(active_from__isnull=True)
        to_filter = Q(active_to__gt=dt.time()) | Q(active_to__isnull=True)

        queryset = CarsharingTariff.objects.filter(
            type_filter & dow_filter & from_filter & to_filter
        )

        grouped = collections.defaultdict(list)
        for row in queryset:
            grouped[row.get_type()].append(row)

        picked = []
        for type_ in types:
            matching = grouped[type_]
            if not matching:
                raise self.TariffMissingError
            # The lowest priority value wins; ties keep query order.
            picked.append(min(matching, key=self._get_tariff_priority))

        return picked

    def _get_tariff_priority(self, tariff):
        """Rank tariffs with an explicit ``active_from`` (0) ahead of those without (1)."""
        return 1 if tariff.active_from is None else 0


class CarsharingTariffManager:
    """Facade over tariff picking and tariff plan maintenance.

    Picking and syncing are delegated to the tariff picker; creating and
    updating plans are delegated to ``CarsharingTariffPlanUpdater``, whose
    errors are re-raised as ``CarsharingTariffManager.Error``.
    """

    class Error(Exception):
        """Raised when a tariff plan cannot be created or updated."""

    def __init__(self, tariff_picker=None):
        """Store the picker (a ``CarsharingTariffPickerV2`` is built if omitted)."""
        if tariff_picker is None:
            tariff_picker = CarsharingTariffPickerV2()

        self._tariff_picker = tariff_picker
        self._tariff_plan_updater = CarsharingTariffPlanUpdater()

    @classmethod
    def from_settings(cls):
        """Build a manager with the default collaborators."""
        return cls()

    def pick_tariff(self, requests):
        """Return the picker's results for ``requests``."""
        return self._tariff_picker.pick_tariff(requests=requests)

    def sync_tariff_plans(self):
        """Ask the picker to reload tariff plans."""
        self._tariff_picker.sync_tariff_plans()

    def create_tariff_plan(self, created_by, name, entries, user_tag=None, car_model_code=None):
        """Create a tariff plan with the given entries and return it.

        Raises:
            CarsharingTariffManager.Error: if the updater rejects the plan; the
                message is the updater's error code and the original exception
                is attached as ``__cause__``.
        """
        try:
            return self._tariff_plan_updater.create(
                created_by=created_by,
                name=name,
                entries=entries,
                user_tag=user_tag,
                car_model_code=car_model_code,
            )
        except CarsharingTariffPlanUpdater.Error as e:
            # Chain explicitly so the traceback marks the updater error as the cause.
            raise self.Error(str(e)) from e

    def update_tariff_plan(self, tariff_plan, name, entries, user_tag=None, car_model_code=None):
        """Update an existing tariff plan, replace its entries, and return it.

        Raises:
            CarsharingTariffManager.Error: if the updater rejects the change; the
                message is the updater's error code and the original exception
                is attached as ``__cause__``.
        """
        try:
            return self._tariff_plan_updater.update(
                tariff_plan=tariff_plan,
                name=name,
                entries=entries,
                user_tag=user_tag,
                car_model_code=car_model_code,
            )
        except CarsharingTariffPlanUpdater.Error as e:
            # Chain explicitly so the traceback marks the updater error as the cause.
            raise self.Error(str(e)) from e


class CarsharingTariffPlanUpdater:
    """Creates and updates tariff plans together with their entries.

    All failures are reported as subclasses of ``Error`` whose message is a
    fixed error code string.
    """

    class Error(Exception):
        """Base class for tariff plan validation failures."""

    class TariffPlanCarModelInvalidError(Error):
        """The given car model code does not match any ``CarModel``."""

        def __init__(self):
            super().__init__('car_model_code.invalid')

    class TariffPlanEntriesInconsistent(Error):
        """An entry has an empty/negative time range or entries of one day overlap."""

        def __init__(self):
            super().__init__('tariff_plan.entries.inconsistent')

    class TariffPlanExistsError(Error):
        """A tariff plan with the requested name already exists."""

        def __init__(self):
            super().__init__('tariff_plan.exists')

    def create(self, created_by, name, entries, user_tag, car_model_code):
        """Create a tariff plan and save ``entries`` attached to it.

        Args:
            created_by: value stored in the plan's ``created_by`` field.
            name: plan name; must not be used by an existing plan.
            entries: unsaved tariff plan entries; each gets ``tariff_plan`` set
                and is saved.
            user_tag: value stored in the plan's ``user_tag`` field.
            car_model_code: code of the car model to bind the plan to; a falsy
                value leaves the plan without a car model.

        Returns:
            The created ``CarsharingTariffPlan``.

        Raises:
            TariffPlanCarModelInvalidError: unknown ``car_model_code``.
            TariffPlanExistsError: a plan named ``name`` already exists.
            TariffPlanEntriesInconsistent: ``entries`` fail validation.
        """
        if car_model_code:
            try:
                car_model = CarModel.objects.get(code=car_model_code)
            except CarModel.DoesNotExist as exc:
                raise self.TariffPlanCarModelInvalidError from exc
        else:
            car_model = None

        if CarsharingTariffPlan.objects.filter(name=name).exists():
            raise self.TariffPlanExistsError

        self._validate_tariff_plan_entries(entries)

        # The plan and all of its entries are written in one transaction.
        with transaction.atomic(savepoint=False):
            tariff_plan = CarsharingTariffPlan.objects.create(
                created_by=created_by,
                created_at=timezone.now(),
                name=name,
                user_tag=user_tag,
                car_model=car_model,
            )
            for entry in entries:
                entry.tariff_plan = tariff_plan
                entry.save()

        return tariff_plan

    def update(self, tariff_plan, name, entries, user_tag, car_model_code):
        """Overwrite a plan's attributes and replace all of its entries.

        Existing entries of ``tariff_plan`` are deleted and ``entries`` are
        saved in their place, all within one transaction.

        Returns:
            The updated ``tariff_plan``.

        Raises:
            TariffPlanEntriesInconsistent: ``entries`` fail validation.
        """
        self._validate_tariff_plan_entries(entries)

        with transaction.atomic(savepoint=False):
            tariff_plan.name = name
            tariff_plan.user_tag = user_tag
            # NOTE(review): unlike create(), the code is assigned without
            # checking that such a car model exists - confirm this is intended.
            tariff_plan.car_model_id = car_model_code
            tariff_plan.save()

            CarsharingTariffPlanEntry.objects.filter(tariff_plan=tariff_plan).delete()
            for entry in entries:
                entry.tariff_plan = tariff_plan
                entry.save()

        return tariff_plan

    def _validate_tariff_plan_entries(self, entries):
        """Check that entries have positive duration and do not overlap per day.

        Entries are grouped by ``day_of_week`` (``None`` forms its own group);
        inside a group no entry may start before the previous one has ended.

        Raises:
            TariffPlanEntriesInconsistent: on the first violation found.
        """
        entries_per_dow = collections.defaultdict(list)

        for entry in entries:
            if entry.get_effective_end_time() <= entry.get_effective_start_time():
                raise self.TariffPlanEntriesInconsistent
            entries_per_dow[entry.day_of_week].append(entry)

        for dow_entries in entries_per_dow.values():
            ordered = sorted(dow_entries, key=lambda entry: entry.get_effective_start_time())

            # Compare every entry with its immediate predecessor. With entries
            # ordered by start time this is enough to detect any overlap.
            for previous_entry, current_entry in zip(ordered, ordered[1:]):
                if previous_entry.get_effective_end_time() > current_entry.get_effective_start_time():
                    raise self.TariffPlanEntriesInconsistent


# One tariff lookup. Fields, in order:
#   date            moment to pick the tariff for
#   timezone        Time zone of times in the result.
#   user_tags       tags of the user, matched against tariff plan user tags
#   car_model_code  car model code, matched against tariff plan car models
#   is_plus_user    whether Plus conditions are applied
CarsharingTariffPickerRequest = collections.namedtuple(
    'CarsharingTariffPickerRequest',
    'date timezone user_tags car_model_code is_plus_user',
)

# Outcome of a lookup: the tariff start time, per-minute costs and, when set,
# a CarsharingTariffPickerFreeParkingResult in `free_parking`.
CarsharingTariffPickerResult = collections.namedtuple(
    'CarsharingTariffPickerResult',
    'start_time ride_cost_per_minute parking_cost_per_minute free_parking',
)

# Free parking details: when the free period ends and the tariff that follows.
CarsharingTariffPickerFreeParkingResult = collections.namedtuple(
    'CarsharingTariffPickerFreeParkingResult',
    'end_date next_tariff',
)


class CarsharingTariffPickerV2:

    def __init__(self, tariff_holder=None):
        """Store the tariff holder; build and start a default one if none is given."""
        holder = tariff_holder
        if holder is None:
            # Nothing injected: create the default holder and start it before use.
            holder = CarsharingTariffHolder()
            holder.start()
        self._tariff_holder = holder

    def sync_tariff_plans(self):
        """Ask the underlying tariff holder to reload its tariff plans."""
        holder = self._tariff_holder
        holder.sync_tariff_plans()

    def pick_tariff(self, requests):
        """Return a list with one picked tariff result per request, in request order."""
        return [self._do_pick_tariff(request) for request in requests]

    def _do_pick_tariff(self, request):
        """Pick the tariff that applies to a single request.

        Steps:
          1. Ask the tariff holder for entries matching the request's date,
             weekday, user tags and car model.
          2. If nothing matches, log an error and return a hard-coded fallback
             tariff (ride 5, parking 2, starting at 00:00, no free parking).
          3. Otherwise take the highest-ranked entry and apply Plus conditions.
          4. If the resulting parking cost is zero, look up the next non-free
             tariff and describe it in ``free_parking``.

        Args:
            request: a ``CarsharingTariffPickerRequest``.

        Returns:
            A ``CarsharingTariffPickerResult``.
        """
        candidate_tariff_plan_entries = self._tariff_holder.get_matching_tariff_plan_entries(
            max_start_dt=request.date,
            min_end_dt=request.date,
            day_of_week=request.date.isoweekday(),
            user_tags=request.user_tags,
            car_model_code=request.car_model_code,
        )

        if not candidate_tariff_plan_entries:
            # Should never happen.
            LOGGER.error('no candidate tariffs for %s', request)
            return CarsharingTariffPickerResult(
                start_time=datetime.time(hour=0),
                ride_cost_per_minute=5,
                parking_cost_per_minute=2,
                free_parking=None,
            )

        # Entries are ranked best-first, so the first one is the winner.
        ranked_tariff_plan_entries = self._rank_tariff_plan_entries(
            candidate_tariff_plan_entries,
        )
        tariff_plan_entry = ranked_tariff_plan_entries[0]

        # For Plus users, alter tariff by applying discount
        tariff_plan_entry = self.apply_plus_tariff_conditions_if_applicable(
            tariff_plan_entry,
            request
        )

        # Without a requested time zone the start time is reported as stored;
        # otherwise it is converted from the plan's time zone.
        if request.timezone is None:
            local_start_time = tariff_plan_entry.get_effective_start_time()
        else:
            local_start_time = self._replace_time_timezone(
                src_time=tariff_plan_entry.get_effective_start_time(),
                src_tz=tariff_plan_entry.tariff_plan.get_timezone(),
                dst_tz=request.timezone,
            )

        free_parking_result = None
        if tariff_plan_entry.parking_cost_per_minute == 0:
            # Both values are None when no following non-free tariff is found.
            after_free_parking_entry, free_period_end_date = self._pick_next_non_free_tariff(
                request,
                current_entry=tariff_plan_entry,
            )

            # Special adjustment for Plus users.
            # The time of free parking for them is extended, so if this applies
            # we just move the border of free parking time to the end of the extended
            # period and repeat the search of the next non-free tariff from that point
            after_free_parking_entry, needs_recalc = (
                self.adjust_after_free_parking_entry_if_plus_user(
                    after_free_parking_entry,
                    request
                )
            )
            if needs_recalc:
                # Same request, but dated at the end of the free period found so far.
                modified_request = CarsharingTariffPickerRequest(
                    date=free_period_end_date,
                    timezone=request.timezone,
                    user_tags=request.user_tags,
                    car_model_code=request.car_model_code,
                    is_plus_user=request.is_plus_user,
                )
                after_free_parking_entry, free_period_end_date = self._pick_next_non_free_tariff(
                    modified_request,
                    current_entry=after_free_parking_entry,
                )

            if free_period_end_date is not None:
                # NOTE(review): the next tariff reports the raw `start_time`
                # attribute with no time zone conversion, unlike the main
                # result's start time above - confirm this is intended.
                free_parking_result = CarsharingTariffPickerFreeParkingResult(
                    next_tariff=CarsharingTariffPickerResult(
                        start_time=after_free_parking_entry.start_time,
                        ride_cost_per_minute=after_free_parking_entry.ride_cost_per_minute,
                        parking_cost_per_minute=after_free_parking_entry.parking_cost_per_minute,
                        free_parking=None,
                    ),
                    end_date=free_period_end_date,
                )

        result = CarsharingTariffPickerResult(
            start_time=local_start_time,
            ride_cost_per_minute=tariff_plan_entry.ride_cost_per_minute,
            parking_cost_per_minute=tariff_plan_entry.parking_cost_per_minute,
            free_parking=free_parking_result,
        )

        # For Plus users the costs of the tariff following free parking are discounted.
        result = self.adjust_after_night_parking_cost_if_plus_user(result, request)

        return result

    def adjust_after_free_parking_entry_if_plus_user(self, after_free_parking_entry, request):
        """Extend free parking over the next entry for Plus users when applicable.

        The entry is adjusted only when it lies entirely inside the window
        between the ``extended_free_parking_start`` and
        ``extended_free_parking_finish`` Plus settings.

        Args:
            after_free_parking_entry: the entry following the free parking
                period, or ``None`` when no such entry was found.
            request: the ``CarsharingTariffPickerRequest`` being served.

        Returns:
            ``(entry, needs_recalc)``. When the adjustment applies, ``entry`` is
            a modified deep copy with zero parking cost and ``needs_recalc`` is
            ``True``; otherwise the input is returned untouched with ``False``.
        """
        # Nothing to adjust when there is no next entry (the search for the next
        # non-free tariff may find none) or the user has no Plus subscription.
        if after_free_parking_entry is None or not request.is_plus_user:
            return after_free_parking_entry, False

        settings = cars.settings.CARSHARING['plus']
        is_within_extended_window = (
            settings['extended_free_parking_start'] <= after_free_parking_entry.get_effective_start_time()
            and settings['extended_free_parking_finish'] >= after_free_parking_entry.get_effective_end_time()
        )
        if not is_within_extended_window:
            return after_free_parking_entry, False

        # Copy only when a change is actually made, so the shared entry held by
        # the tariff holder is never mutated.
        modified_entry = copy.deepcopy(after_free_parking_entry)
        # NOTE(review): this sets `finish_time`, while
        # apply_plus_tariff_conditions_if_applicable also sets `end_time` -
        # confirm which attribute the entry model actually reads.
        modified_entry.finish_time = settings['extended_free_parking_finish']
        modified_entry.parking_cost_per_minute = decimal.Decimal(0)
        return modified_entry, True

    def adjust_after_night_parking_cost_if_plus_user(self, result, request):
        """Discount the tariff that follows free parking for Plus users.

        Returns ``result`` unchanged when it has no free parking information or
        the user is not a Plus user. Otherwise returns a new result whose
        ``free_parking.next_tariff`` costs are multiplied by the Plus
        ``discount_multiplier`` setting and rounded to one decimal place.
        """
        plus_settings = cars.settings.CARSHARING['plus']
        if not (result.free_parking is not None and request.is_plus_user):
            return result

        free_parking = result.free_parking
        upcoming_tariff = free_parking.next_tariff
        multiplier = plus_settings['discount_multiplier']

        # Multiply costs for the next parking by the discount multiplier so the
        # app shows them correctly.
        discounted_upcoming_tariff = CarsharingTariffPickerResult(
            start_time=upcoming_tariff.start_time,
            ride_cost_per_minute=round(upcoming_tariff.ride_cost_per_minute * multiplier, 1),
            parking_cost_per_minute=round(upcoming_tariff.parking_cost_per_minute * multiplier, 1),
            free_parking=None,
        )

        return CarsharingTariffPickerResult(
            start_time=result.start_time,
            ride_cost_per_minute=result.ride_cost_per_minute,
            parking_cost_per_minute=result.parking_cost_per_minute,
            free_parking=CarsharingTariffPickerFreeParkingResult(
                next_tariff=discounted_upcoming_tariff,
                end_date=free_parking.end_date,
            ),
        )

    def apply_plus_tariff_conditions_if_applicable(self, tariff_plan_entry, request):
        """Return the entry with Plus discount and free parking rules applied.

        Non-Plus requests get ``tariff_plan_entry`` back untouched. For Plus
        users a deep copy is returned with both per-minute costs multiplied by
        the ``discount_multiplier`` setting and rounded to one decimal place,
        and with the extended free parking settings applied.
        """
        plus_settings = cars.settings.CARSHARING['plus']
        if not request.is_plus_user:
            return tariff_plan_entry

        # Work on a copy so the original entry object is left as it was.
        discounted = copy.deepcopy(tariff_plan_entry)

        # Discounted prices, rounded to one decimal place.
        multiplier = plus_settings['discount_multiplier']
        discounted.ride_cost_per_minute = round(discounted.ride_cost_per_minute * multiplier, 1)
        discounted.parking_cost_per_minute = round(discounted.parking_cost_per_minute * multiplier, 1)

        # Parking that is already free lasts until the extended finish time.
        parking_is_free = discounted.parking_cost_per_minute == 0
        if parking_is_free:
            discounted.end_time = plus_settings['extended_free_parking_finish']

        # An entry lying entirely inside the extended window becomes free parking.
        # NOTE(review): this branch sets `finish_time` while the one above sets
        # `end_time` - confirm which attribute the entry model actually reads.
        inside_extended_window = (
            plus_settings['extended_free_parking_start'] <= discounted.get_effective_start_time()
            and plus_settings['extended_free_parking_finish'] >= discounted.get_effective_end_time()
        )
        if inside_extended_window:
            discounted.finish_time = plus_settings['extended_free_parking_finish']
            discounted.parking_cost_per_minute = decimal.Decimal(0)

        return discounted

    def _rank_tariff_plan_entries(self, candidate_tariff_plan_entries):
        """Return a new list of the entries ordered from highest to lowest sorting key."""
        ranked = list(candidate_tariff_plan_entries)
        ranked.sort(key=self._get_tariff_plan_entry_sorting_key, reverse=True)
        return ranked

    def _get_tariff_plan_entry_sorting_key(self, tariff_plan_entry):
        """Return ``(priority, plan.created_at)`` used to rank an entry.

        The priority is a bit mask: bit 0 is always set, a plan user tag sets
        bit 3, a plan car model sets bit 2 and an entry day of week sets bit 1.
        """
        plan = tariff_plan_entry.tariff_plan

        # (is the attribute set?, bit contributed to the priority)
        weighted_flags = (
            (plan.user_tag is not None, 1 << 3),
            (plan.car_model_id is not None, 1 << 2),
            (tariff_plan_entry.day_of_week is not None, 1 << 1),
        )
        priority = 1 + sum(bit for is_set, bit in weighted_flags if is_set)

        return priority, plan.created_at

    def _pick_next_non_free_tariff(self, request, current_entry):
        """Find the paid-parking entry that follows ``current_entry``.

        Candidates are the holder's entries matching the request's user tags
        and car model whose parking cost is above zero. They are scanned in the
        order the holder returns them and the first one satisfying any of the
        three checks below is chosen.

        Args:
            request: a ``CarsharingTariffPickerRequest``; its ``date`` is the
                base date for the returned start moment.
            current_entry: the entry that is currently in effect.

        Returns:
            ``(entry, start_date)``, or ``(None, None)`` when no candidate
            qualifies.
        """
        candidate_tariff_plan_entries = self._tariff_holder.get_matching_tariff_plan_entries(
            user_tags=request.user_tags,
            car_model_code=request.car_model_code,
            func=lambda entry: entry.parking_cost_per_minute > 0,
        )

        next_non_free_tariff_plan_entry = None
        start_date = None
        date_localizer = DateLocalizer()

        current_priority = self._get_tariff_plan_entry_sorting_key(current_entry)
        for entry in candidate_tariff_plan_entries:
            # The request date expressed in the time zone of the candidate's plan.
            base_date = date_localizer.localize(
                dt=request.date,
                tz=entry.tariff_plan.get_timezone(),
            )
            # Check 1: both entries are bound to a weekday and the candidate's
            # weekday number is greater than the current one (a current weekday
            # of 7 is treated as 0, so any weekday-bound candidate qualifies).
            is_next_entry_next_day = (
                current_entry.day_of_week
                and entry.day_of_week
                and current_entry.day_of_week - (7 if current_entry.day_of_week == 7 else 0) < entry.day_of_week
            )
            if is_next_entry_next_day:
                next_non_free_tariff_plan_entry = entry
                start_date = self._replace_time(
                    dt=base_date,
                    t=entry.get_effective_start_time(),
                )
                # Move forward day by day until the candidate's weekday is reached.
                while start_date.isoweekday() != entry.day_of_week:
                    start_date += datetime.timedelta(days=1)
                break

            # Check 2: the candidate ends later than both the start and the end
            # of the current entry.
            is_current_entry_expires = (
                current_entry.get_effective_start_time() < entry.get_effective_end_time()
                and current_entry.get_effective_end_time() < entry.get_effective_end_time()
            )
            if is_current_entry_expires:
                next_non_free_tariff_plan_entry = entry
                # The paid tariff starts at whichever comes later: the end of
                # the current entry or the start of the candidate.
                start_date = self._replace_time(
                    dt=base_date,
                    t=max(
                        current_entry.get_effective_end_time(),
                        entry.get_effective_start_time(),
                    ),
                )
                break

            # Check 3: the candidate has a greater sorting key than the current entry.
            is_greater_priority = self._get_tariff_plan_entry_sorting_key(entry) > current_priority
            if is_greater_priority:
                next_non_free_tariff_plan_entry = entry
                start_date = self._replace_time(
                    dt=base_date,
                    t=entry.get_effective_start_time(),
                )
                break

        return next_non_free_tariff_plan_entry, start_date

    def _replace_time_timezone(self, src_time, src_tz, dst_tz):
        if src_tz == dst_tz:
            return src_time

        dt = self._replace_time(dt=datetime.datetime.now(), t=src_time)
        dst_time = src_tz.localize(dt).astimezone(dst_tz).time()

        return dst_time

    def _replace_time(self, dt, t):
        return dt.replace(
            hour=t.hour,
            minute=t.minute,
            second=t.second,
            microsecond=t.microsecond,
        )


class CarsharingTariffHolder:
    """In-memory snapshot of all tariff plans, refreshed by a background thread.

    ``start()`` launches a daemon thread that loads the plans once and then
    reloads them every ``update_interval``. Lookups are served from the
    snapshot without touching the database.
    """

    def __init__(self, update_interval=datetime.timedelta(minutes=5)):
        """Create an empty, not yet started holder.

        Args:
            update_interval: ``datetime.timedelta`` to sleep between reloads.
        """
        self._tariff_plans = None
        self._sorted_tariff_plan_entries = None
        self._update_interval = update_interval
        self._updater_thread = None
        self._started = threading.Event()

    def start(self):
        """Start the updater thread and block until its first sync attempt is over."""
        # `daemon=True` replaces the deprecated `Thread.setDaemon(True)`.
        self._updater_thread = threading.Thread(
            target=self._run_updater_loop,
            daemon=True,
        )
        self._updater_thread.start()
        self._started.wait()

    def sync_tariff_plans(self):
        """Reload all tariff plans and their entries from the read-only database.

        Entries of all plans are flattened and sorted by day of week (entries
        without one come first) and then by effective start time.
        """
        tariff_plans = tuple(
            CarsharingTariffPlan.objects
            .using(cars.settings.DB_RO_ID)
            .select_related('car_model')
            .prefetch_related('entries')
            .all()
        )
        sorted_tariff_plan_entries = tuple(
            sorted(
                [
                    entry
                    for tariff_plan in tariff_plans
                    for entry in tariff_plan.entries.all()
                ],
                key=lambda entry: (entry.day_of_week or 0, entry.get_effective_start_time()),
            )
        )

        # Publish only after both collections are fully built. Entries go first
        # because readers use `_tariff_plans` as the "synced" marker.
        self._sorted_tariff_plan_entries = sorted_tariff_plan_entries
        self._tariff_plans = tariff_plans

    def get_matching_tariff_plan_entries(self, user_tags, car_model_code,
                                         max_start_dt=None, min_end_dt=None,
                                         day_of_week=None, func=None):
        """Return the entries of the snapshot that pass all given filters.

        Args:
            user_tags: container of the user's tags; an entry whose plan has a
                user tag matches only if that tag is in ``user_tags``.
            car_model_code: an entry whose plan is bound to a car model matches
                only if the model equals this code.
            max_start_dt: if given, drops entries starting after this moment.
            min_end_dt: if given, drops entries ending at or before this moment.
            day_of_week: if given, drops entries bound to a different weekday.
            func: optional predicate; entries for which it is falsy are dropped.

        Returns:
            A list of matching entries in snapshot order (sorted by day of week,
            then by effective start time).
        """
        assert self._tariff_plans is not None, 'tariff plans are not synced'

        date_localizer = DateLocalizer()

        matching_entries = []
        for entry in self._sorted_tariff_plan_entries:
            # Time bounds are compared in the time zone of the entry's plan.
            tz = entry.tariff_plan.get_timezone()
            local_max_start_dt = date_localizer.localize(dt=max_start_dt, tz=tz)
            local_min_end_dt = date_localizer.localize(dt=min_end_dt, tz=tz)

            if local_max_start_dt is not None:
                if entry.day_of_week is None or entry.day_of_week == local_max_start_dt.isoweekday():
                    if entry.get_effective_start_time() > local_max_start_dt.time():
                        continue
                # Entries of a later weekday are dropped; on Mondays the bound
                # is raised by 7, so no other-day entry is dropped here.
                elif entry.day_of_week > (
                        local_max_start_dt.isoweekday()
                        + (7 if local_max_start_dt.isoweekday() == 1 else 0)):
                    continue

            if local_min_end_dt is not None:
                if entry.day_of_week is None or entry.day_of_week == local_min_end_dt.isoweekday():
                    if entry.get_effective_end_time() <= local_min_end_dt.time():
                        continue
                # Same weekday rule as for the start bound above.
                elif entry.day_of_week > (
                        local_min_end_dt.isoweekday()
                        + (7 if local_min_end_dt.isoweekday() == 1 else 0)):
                    continue

            if entry.day_of_week and day_of_week and entry.day_of_week != day_of_week:
                continue

            if entry.tariff_plan.user_tag and entry.tariff_plan.user_tag not in user_tags:
                continue

            if entry.tariff_plan.car_model_id and entry.tariff_plan.car_model_id != car_model_code:
                continue

            if func is not None and not func(entry):
                continue

            matching_entries.append(entry)

        return matching_entries

    def _run_updater_loop(self):
        """Body of the updater thread: sync once, then re-sync forever.

        A failing sync is logged and retried after the next interval instead of
        terminating the thread.
        """
        try:
            self.sync_tariff_plans()
        except Exception:  # pylint: disable=broad-except
            logging.getLogger(__name__).exception('initial tariff plans sync failed')
        finally:
            # Always release start(), otherwise a failed first sync would
            # block the caller forever.
            self._started.set()

        while True:
            time.sleep(self._update_interval.total_seconds())
            try:
                self.sync_tariff_plans()
            except Exception:  # pylint: disable=broad-except
                # Keep serving the previous snapshot and try again later.
                logging.getLogger(__name__).exception('tariff plans sync failed')


class DateLocalizer:
    """Converts datetimes to a time zone, caching results per zone and datetime."""

    def __init__(self):
        # Maps tz -> {source datetime -> datetime converted to tz}.
        self._cache = {}

    def localize(self, dt, tz):
        """Return ``dt`` converted to ``tz`` via ``astimezone``; ``None`` stays ``None``.

        Repeated calls with the same arguments return the cached object.
        """
        if dt is None:
            return None

        per_tz_cache = self._cache.setdefault(tz, {})
        try:
            return per_tz_cache[dt]
        except KeyError:
            converted = dt.astimezone(tz)
            per_tz_cache[dt] = converted
            return converted
