# coding=utf-8
from __future__ import unicode_literals

import numpy as np

from travel.avia.shared_flights.lib.python.date_utils.date_index import DateIndex
from travel.avia.shared_flights.lib.python.date_utils.date_matcher import DateMatcher

# Number of bits stored per int64 element of a DateMask. Kept at 63 (not 64) so
# the sign bit is never set: elements stay non-negative, which the `elem > 0`
# bit-scanning loops in DateMaskMatcher rely on.
BITS_PER_ELEM = 63


def get_all_one_bits(num_bits):
    """Return an integer whose lowest ``num_bits`` bits are all set.

    Equivalent to ``2 ** num_bits - 1``; a non-positive ``num_bits`` yields 0.
    The result is always a plain Python int (arbitrary precision), even when
    ``num_bits`` is a numpy integer.
    """
    import operator

    # operator.index converts numpy integers to a Python int (so the shift can't
    # overflow int64) while still rejecting floats, like range() did.
    num_bits = operator.index(num_bits)
    if num_bits <= 0:
        return 0
    return (1 << num_bits) - 1

# Integer with the low BITS_PER_ELEM bits set (2**63 - 1, the int64 maximum).
# XOR-ing a value into it yields a mask that keeps every usable bit except
# those set in the value, which is how bits are cleared in DateMask elements.
ALL_ONE_BITS = get_all_one_bits(BITS_PER_ELEM)


class DateMask(object):
    """Compact bit set of day positions backed by an int64 numpy array.

    Position ``pos`` lives in element ``pos // BITS_PER_ELEM`` at bit
    ``pos % BITS_PER_ELEM``. ``max_days // BITS_PER_ELEM + 1`` elements are
    allocated, which is enough room for every position in ``0..max_days``.
    No bounds checking is done; callers are expected to pass positions inside
    that range.
    """

    def __init__(self, max_days):
        element_count = max_days // BITS_PER_ELEM + 1
        self.mask = np.zeros(element_count, dtype=np.int64)

    @staticmethod
    def _locate(pos):
        # (array element index, bit offset inside that element) for a position.
        return divmod(pos, BITS_PER_ELEM)

    def get_pos(self, pos):
        """Return whether the bit for ``pos`` is set."""
        index, offset = self._locate(pos)
        return ((self.mask[index] >> offset) & 1) == 1

    def add_pos(self, pos):
        """Set the bit for ``pos``."""
        index, offset = self._locate(pos)
        self.mask[index] |= 1 << offset

    def remove_pos(self, pos):
        """Clear the bit for ``pos``."""
        index, offset = self._locate(pos)
        self.mask[index] &= ALL_ONE_BITS ^ (1 << offset)

    def is_empty(self):
        """Return True when no bit is set in any element."""
        return not self.mask.any()


# Efficiently manages date mask objects
class DateMaskMatcher(object):
    """Create and combine DateMask objects that share one start date.

    Mask position ``p`` stands for the date whose index in ``date_index`` is
    ``start_index + p``, where ``start_index = date_index.get_index(start_date)``.
    Positions ``0..max_days`` (inclusive) form the matcher's window; the
    add/remove methods silently ignore dates outside it.

    To compare or combine masks efficiently their start dates must be equal,
    so every mask passed to the binary operations below is expected to have
    been created by this matcher.
    """

    def __init__(self, date_index, start_date, max_days):
        self._start_date = start_date
        self._date_index = date_index
        self._date_matcher = DateMatcher(self._date_index)
        self._start_index = date_index.get_index(start_date)
        self._max_days = max_days

    def new_date_mask(self):
        """Return an empty DateMask sized for this matcher's window."""
        return DateMask(self._max_days)

    # ------------------------------------------------------------------
    # Private helpers

    def _window_pos(self, date_str):
        # Mask position of date_str, or None when it lies outside 0..max_days.
        pos = self._date_index.get_index_for_date_str(date_str) - self._start_index
        if pos < 0 or pos > self._max_days:
            return None
        return pos

    def _range_positions(self, operating_from_str, operating_until_str, operating_on):
        # Yield mask positions of days in [from, until] (clamped to the window)
        # on which the operating_on weekday spec operates.
        index_from = self._date_index.get_index_for_date_str(operating_from_str)
        index_until = self._date_index.get_index_for_date_str(operating_until_str)
        if index_from < self._start_index:
            index_from = self._start_index
        if index_until > self._start_index + self._max_days:
            index_until = self._start_index + self._max_days
        operating_on_bits = DateMatcher.get_bits(operating_on)
        for day_index in range(index_from, index_until + 1):
            if self._date_matcher.operates_on_index(day_index, operating_on_bits):
                yield day_index - self._start_index

    @staticmethod
    def _common_size(*date_masks):
        # Number of leading array elements present in every given mask.
        return min(len(date_mask.mask) for date_mask in date_masks)

    @staticmethod
    def _pack(date_mask):
        # Concatenate the mask elements into one Python int: bit p of the
        # result is mask position p.
        packed = 0
        for offset, elem in enumerate(date_mask.mask):
            packed |= (int(elem) & ALL_ONE_BITS) << (offset * BITS_PER_ELEM)
        return packed

    @staticmethod
    def _unpack(packed, date_mask):
        # Inverse of _pack: split packed back into the mask's elements in place.
        for offset in range(len(date_mask.mask)):
            date_mask.mask[offset] = (packed >> (offset * BITS_PER_ELEM)) & ALL_ONE_BITS

    # ------------------------------------------------------------------
    # Adding dates

    def add_date(self, date, date_mask):
        """Set ``date`` (converted with ``DateIndex.get_str``) in ``date_mask``."""
        self.add_date_str(DateIndex.get_str(date), date_mask)

    def add_date_str(self, date_str, date_mask):
        """Set ``date_str`` in ``date_mask``; dates outside the window are ignored."""
        pos = self._window_pos(date_str)
        if pos is not None:
            date_mask.add_pos(pos)

    def add_range(self, operating_from_str, operating_until_str, operating_on, date_mask):
        """Set every day in [from, until] matching ``operating_on``.

        The range is clamped to the window. ``operating_on`` is whatever
        ``DateMatcher.get_bits`` accepts.
        """
        for pos in self._range_positions(operating_from_str, operating_until_str, operating_on):
            date_mask.add_pos(pos)

    # Both masks should be created by this matcher
    def add_mask(self, date_mask_to_update, date_mask_to_read):
        """In place: ``date_mask_to_update |= date_mask_to_read`` (set union)."""
        size = self._common_size(date_mask_to_update, date_mask_to_read)
        date_mask_to_update.mask[:size] |= date_mask_to_read.mask[:size]

    # Both masks should be created by this matcher
    def intersect_mask(self, date_mask_to_update, date_mask_to_read):
        """In place: ``date_mask_to_update &= date_mask_to_read`` (set intersection).

        Only the elements both masks have are combined; masks from the same
        matcher always have equal length.
        """
        size = self._common_size(date_mask_to_update, date_mask_to_read)
        date_mask_to_update.mask[:size] &= date_mask_to_read.mask[:size]

    def shift_days(self, date_mask, days):
        """Shift every set day in ``date_mask`` by ``days`` positions, in place.

        Positive ``days`` moves dates later, negative earlier. Bits shifted past
        either end of the mask's storage are dropped and vacated positions are
        cleared. A shift at least as large as the storage clears the whole mask.

        The storage (BITS_PER_ELEM bits per element) can extend past
        ``max_days``, so a forward shift may leave bits beyond the window;
        ``get_dates`` ignores those, ``get_date_indexes`` does not.
        """
        if not days:
            return

        # Plain int so shifting the arbitrary-precision packed value can't be
        # routed through numpy int64 arithmetic.
        days = int(days)
        capacity = BITS_PER_ELEM * len(date_mask.mask)
        if abs(days) >= capacity:
            # Everything is shifted out; the per-bit path would index out of range.
            date_mask.mask[:] = 0
            return

        packed = self._pack(date_mask)
        if days > 0:
            packed = (packed << days) & ((1 << capacity) - 1)
        else:
            packed >>= -days
        self._unpack(packed, date_mask)

    # ------------------------------------------------------------------
    # Removing dates

    def remove_date(self, date, date_mask):
        """Clear ``date`` (converted with ``DateIndex.get_str``) in ``date_mask``."""
        self.remove_date_str(DateIndex.get_str(date), date_mask)

    def remove_date_str(self, date_str, date_mask):
        """Clear ``date_str`` in ``date_mask``; dates outside the window are ignored."""
        pos = self._window_pos(date_str)
        if pos is not None:
            date_mask.remove_pos(pos)

    def remove_range(self, operating_from_str, operating_until_str, operating_on, date_mask):
        """Clear every day in [from, until] matching ``operating_on``.

        The range is clamped to the window. ``operating_on`` is whatever
        ``DateMatcher.get_bits`` accepts.
        """
        for pos in self._range_positions(operating_from_str, operating_until_str, operating_on):
            date_mask.remove_pos(pos)

    # Both masks should be created by this matcher
    def remove_mask(self, date_mask_to_update, date_mask_to_read):
        """In place: remove every day set in ``date_mask_to_read`` (set difference)."""
        size = self._common_size(date_mask_to_update, date_mask_to_read)
        date_mask_to_update.mask[:size] &= ALL_ONE_BITS ^ date_mask_to_read.mask[:size]

    # Returns mask which is operating on dates from date_mask1 which are not in date_mask2
    # Both masks should be created by this matcher
    def diff1(self, date_mask1, date_mask2):
        """Return a new mask holding the days of ``date_mask1`` absent from ``date_mask2``."""
        result = self.new_date_mask()
        size = self._common_size(date_mask1, date_mask2, result)
        result.mask[:size] = date_mask1.mask[:size] & (ALL_ONE_BITS ^ date_mask2.mask[:size])
        return result

    # ------------------------------------------------------------------
    # Reading dates

    def get_dates(self, date_mask):
        """Return ``date_index.get_date_int`` of every set day inside the window, ascending."""
        last_index = self._start_index + self._max_days
        result = []
        for date_index in self.get_date_indexes(date_mask):
            if date_index > last_index:
                # Indexes are ascending, so everything after is outside the window too.
                break
            result.append(self._date_index.get_date_int(date_index))
        return result

    def get_date_indexes(self, date_mask):
        """Return the date indexes (``start_index + position``) of all set bits, ascending.

        NOTE(review): unlike get_dates, positions beyond ``max_days`` (which
        shift_days can produce) are not filtered out here; confirm whether
        callers rely on that before changing it.
        """
        result = []
        for elem_offset, elem in enumerate(date_mask.mask):
            base = self._start_index + elem_offset * BITS_PER_ELEM
            value = int(elem)
            shift = 0
            while value > 0:
                if value & 1:
                    result.append(base + shift)
                value >>= 1
                shift += 1
        return result

    # Converts days from mask into (operating_from, operating_to, operating_on) triplets
    def generate_masks(self, date_mask):
        """Summarise the set days as (operating_from, operating_until, operating_on) triples.

        Days are grouped into consecutive 7-day blocks starting at the first set
        day; each block is reduced to a bitmask with bit ``weekday`` set for each
        weekday present. Consecutive non-empty blocks with identical bitmasks
        are merged into one triple via create_dates_triple; empty blocks break
        runs. Returns ``[]`` for an empty mask.
        """
        result = []
        date_indexes = self.get_date_indexes(date_mask)
        if not date_indexes:
            return result

        start = date_indexes[0]
        weeks = {}
        for date_index in date_indexes:
            week_index = (date_index - start) // 7
            weekday = self._date_index.get_weekday_by_index(date_index)
            weeks[week_index] = weeks.get(week_index, 0) | (1 << int(weekday))
        # Week 0 always exists because it contains the first set day.
        max_week_index = max(weeks)

        # Current run of identical weeks; start/end of -1 means "no open run".
        mask_start_week = 0
        mask_end_week = 0
        mask_operating_days = weeks[0]
        for week_index in range(1, max_week_index + 1):
            operating_days = weeks.get(week_index, 0)
            if operating_days and operating_days == mask_operating_days:
                mask_end_week = week_index
                continue
            if mask_start_week >= 0 and mask_operating_days:
                result.append(self.create_dates_triple(
                    start + mask_start_week * 7, start + mask_end_week * 7 + 6, mask_operating_days))
                mask_start_week = -1
                mask_end_week = -1
                mask_operating_days = 0
            if not operating_days:
                continue
            mask_start_week = week_index
            mask_end_week = week_index
            mask_operating_days = operating_days

        if mask_start_week >= 0:
            result.append(self.create_dates_triple(
                start + mask_start_week * 7, start + mask_end_week * 7 + 6, mask_operating_days))

        return result

    def create_dates_triple(self, start_index, end_index, operating_bits):
        """Build an (operating_from, operating_until, operating_on) triple.

        ``operating_bits`` has bit ``i`` set for each weekday number ``i`` in
        1..7. ``operating_on`` is those weekday numbers concatenated in ascending
        order as an int (e.g. bits 1, 3, 5 give 135). When exactly one weekday is
        set and the span covers at most a week, the range collapses to the one
        date in the span falling on that weekday, if one is found.
        """
        operates_on = 0
        for weekday in range(1, 8):
            if operating_bits & (1 << weekday):
                operates_on = operates_on * 10 + weekday

        # A single digit means a single weekday; within one week that is one date.
        if operates_on < 10 and end_index - start_index <= 7:
            for date_index in range(start_index, end_index + 1):
                if self._date_index.get_weekday_by_index(date_index) == operates_on:
                    date_str = self._date_index.get_date_str(date_index)
                    return (date_str, date_str, operates_on)

        return (
            self._date_index.get_date_str(start_index),
            self._date_index.get_date_str(end_index),
            operates_on,
        )
