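"""
Derive a card color scheme (card background, card text, button background
and button text colors) from an RGB/RGBA image.

Related implementations:
https://github.yandex-team.ru/kovchiy/colorwiz/blob/master/colorwiz.js
https://github.yandex-team.ru/moblauncher/yandex_launcher/blob/next-release/src/com/yandex/launcher/util/RecColors.java
"""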

# !/usr/bin/env python
# -*- coding: utf-8 -*-

import math

import numpy as np

from color_info import ColorInfo

# Bit flags for the per-pixel classes produced by get_color_type().
COLOR_TRANSPARENT = 1 << 0
COLOR_WHITE = 1 << 1
COLOR_BLACK = 1 << 2
# NB: despite its name this mask also contains COLOR_TRANSPARENT, so testing
# ``types & COLOR_WHITE_OR_BLACK == 0`` keeps only opaque, non-black,
# non-white pixels.
COLOR_WHITE_OR_BLACK = COLOR_TRANSPARENT | COLOR_WHITE | COLOR_BLACK

# 0xFF; presumably the alpha value of a fully opaque pixel (not referenced in
# the code visible here).
COLOR_OPAQUE = 255


def get_background_color(image):
    """Return the image's uniform border color as a ``ColorInfo``, or ``None``.

    The frame inset one pixel from the image edge (its first and last inner
    row and column) is compared with the reference pixel ``image[1, 1]``
    using :func:`has_rect_color`. If all four sides match, the reference
    pixel's color is returned.

    ``None`` is returned when either of the first two dimensions is smaller
    than 4, when the image has more than 3 channels and the reference
    pixel's channel 3 (alpha) is below 128, or when any side does not match.

    :param image: 3-D array indexed as ``image[axis0, axis1, channel]``
        (``calculate`` passes ``np.asarray`` of an RGB/RGBA image).
    """
    rows, cols, channels = image.shape
    if min(rows, cols) < 4:
        return None

    reference = image[1, 1]
    if channels > 3 and reference[3] < 128:
        return None

    # (start along axis 0, start along axis 1, extent axis 0, extent axis 1)
    frame_sides = (
        (1, 1, rows - 2, 1),
        (1, cols - 2, rows - 2, 1),
        (1, 1, 1, cols - 2),
        (rows - 2, 1, 1, cols - 2),
    )
    side_matches = [
        has_rect_color(image, start0, start1, extent0, extent1, reference)
        for start0, start1, extent0, extent1 in frame_sides
    ]
    if all(side_matches):
        return ColorInfo(reference)
    return None


def detect_uniform_color(colors, level=0.96, delta=2):
    """Return whether a set of pixels is (almost) a single color.

    Each of the first three channels is checked independently. A 256-bin
    histogram is built, its most frequent value ``peak`` is found, and the
    pixels whose value lies in ``[peak - delta, peak + delta]`` (clipped to
    the histogram) are counted. The result is true only when, for every
    channel, that count is strictly greater than ``int(len(colors) * level)``.

    :param colors: 2-D array with one pixel per row and at least 3 columns;
        values are assumed to be integers in 0..255 (uint8 pixel data).
    :param level: required fraction of pixels near the peak.
    :param delta: half-width of the accepted window around the peak.
    :return: NumPy boolean.
    """
    rgb = colors[:, :3]
    histograms = np.stack(
        [np.bincount(channel, minlength=256) for channel in rgb.T], axis=1)

    # Row i holds the number of pixels with value < i, so any window count
    # is a difference of two rows.
    cumulative = np.zeros((257, 3), dtype=np.uint32)
    cumulative[1:, :] = np.cumsum(histograms, axis=0)

    peaks = np.argmax(histograms, axis=0)
    window_lo = np.clip(peaks - delta, 0, 256)
    window_hi = np.clip(peaks + delta + 1, 0, 256)
    channel_idx = np.arange(3)
    near_peak = cumulative[window_hi, channel_idx] - cumulative[window_lo, channel_idx]

    required = int(len(colors) * level)
    # noinspection PyTypeChecker
    return np.all(near_peak > required)


def calculate(image):
    """Compute card, text and button colors for an RGB or RGBA image.

    :param image: object with a ``mode`` attribute equal to ``'RGB'`` or
        ``'RGBA'`` that ``np.asarray`` turns into a 3-D array (e.g. a PIL
        image).
    :return: dict mapping ``'card_background'``, ``'card_text'``,
        ``'button_background'`` and ``'button_text'`` to ``str()`` of the
        computed colors.
    :raises AssertionError: for any other image mode (the check is an
        ``assert``, so it is skipped under ``python -O``).
    """
    assert image.mode in ('RGB', 'RGBA'), 'Only RGB and RGBA image modes are supported'
    pixels = np.asarray(image)
    rows, cols, _ = pixels.shape

    # A uniform border, when present, becomes the card background.
    background = get_background_color(pixels)

    pixel_types = get_color_type(pixels)
    opaque = (pixel_types & COLOR_TRANSPARENT) == 0
    # COLOR_WHITE_OR_BLACK includes the transparent bit, so this keeps only
    # opaque pixels that are neither black nor white.
    opaque_colorful = (pixel_types & COLOR_WHITE_OR_BLACK) == 0

    # If fewer than 12% of pixels are "colorful", sample from every opaque
    # pixel instead (black and white included).
    too_few_colorful = np.sum(opaque_colorful) < 0.12 * rows * cols
    sample_mask = opaque if too_few_colorful else opaque_colorful
    initial_points, dark_count = get_initial_color_points(pixels[sample_mask])

    if background is None:
        is_dark = dark_count > 1
    else:
        is_dark = background.is_dark()
    scheme_class = DarkScheme if is_dark else LightScheme
    return scheme_class(background, initial_points).get_colors()


def get_initial_color_points(colors):
    """Sample four colors from a flat list of pixels and count the dark ones.

    The pixels are treated as if laid out row-major on a ``side x side``
    square, where ``side = int(sqrt(len(colors)))``. Four points are sampled:
    two on row ``side // 10`` and two on row ``side // 2 + 1``, each at
    columns ``side // 4`` and ``3 * side // 4``.

    :param colors: 2-D array with one pixel per row (``calculate`` passes
        ``image[mask]``).
    :return: ``(color_info, dark_count)``. ``color_info`` is a list of four
        ``ColorInfo`` objects sorted by ``ColorInfo.get_adapted_l``.
        ``dark_count`` is how many of them report ``is_dark()``. With fewer
        than 4 pixels, four separate opaque-white colors and a count of 0
        are returned.
    """
    counter, _ = colors.shape
    side = int(math.sqrt(counter))
    if side <= 1:
        # Not enough pixels to sample. Build four distinct objects so that
        # mutating one entry can never affect the others.
        return [ColorInfo((0xff, 0xff, 0xff, 0xff)) for _ in range(4)], 0

    # Floor division keeps the indices integral. With ``/`` Python 3 yields
    # floats, which NumPy rejects as array indices.
    y0 = side // 10
    y1 = side // 2 + 1
    x0 = side // 4
    x1 = 3 * side // 4

    points = np.array([
        y0 * side + x0,
        y0 * side + x1,
        y1 * side + x0,
        y1 * side + x1,
    ])
    # For side == 2 with only 4 or 5 pixels the lower row (indices 4 and 5)
    # lies past the end of ``colors``; clamp to the last pixel.
    points = np.minimum(points, counter - 1)

    color_info = [ColorInfo(color) for color in colors[points, :]]
    dark_count = sum(1 for color in color_info if color.is_dark())
    color_info.sort(key=ColorInfo.get_adapted_l)
    return color_info, dark_count


class ColorScheme(object):
    """Base class that derives card, text and button colors from samples.

    Subclasses provide ``get_card_color()``, ``get_button_color(card_color)``
    and ``get_text_color(card_color)``. The constructor immediately runs
    :meth:`process_colors`, which calls them and stores the results on the
    instance.
    """

    # Lightness offsets, applied to ``hsl[2]`` of ColorInfo objects.
    BUTTON_DELTA = 0.2
    BACKGROUND_DELTA = 0.15
    BUTTON_TEXT_DELTA = 0.6

    def __init__(self, background_color, initial_points):
        # ``calculate`` passes a ColorInfo (or None) and the sampled points.
        self.background_color = background_color
        self.initial_points = initial_points
        self.process_colors()

    def process_colors(self):
        """Compute card, button, text and button-text colors, in that order."""
        card = self.get_card_color()
        self.card_color = card
        self.button_color = self.get_button_color(card)
        self.text_color = self.get_text_color(card)
        self.button_text_color = self.get_button_text_color(self.button_color)

    def get_button_text_color(self, button_color):
        """Return a copy of ``button_color`` shifted in lightness for contrast.

        A dark button gets lighter text (capped at 1.0); a light button gets
        darker text (floored at 0.0).
        """
        label = ColorInfo(button_color)
        step = self.BUTTON_TEXT_DELTA
        if not label.is_dark():
            label.hsl[2] = max(button_color.hsl[2] - step, 0.0)
        else:
            label.hsl[2] = min(button_color.hsl[2] + step, 1.0)
        return label

    def get_colors(self):
        """Return the scheme as a dict of ``str()``-formatted colors."""
        named = (
            ("card_background", self.card_color),
            ("card_text", self.text_color),
            ("button_background", self.button_color),
            ("button_text", self.button_text_color),
        )
        return {key: str(value) for key, value in named}

    def change_background_color(self, card_color):
        """Shift ``card_color``'s lightness by BACKGROUND_DELTA in place.

        The color is darkened when that keeps lightness at or above 0.1;
        otherwise it is lightened. The same object is returned.
        """
        hsl = card_color.hsl
        step = self.BACKGROUND_DELTA
        if hsl[2] >= 0.1 + step:
            hsl[2] -= step
        else:
            hsl[2] += step
        return card_color


class LightScheme(ColorScheme):
    """Scheme for light images.

    The card comes from the background color or the last sample point. Text
    comes from the first sample point and is darkened for contrast. The
    button is taken from sample point 2, or else 1. Sample points are sorted
    by ``ColorInfo.get_adapted_l`` in ``get_initial_color_points``.
    """

    def get_card_color(self):
        source = self.initial_points[3] if self.background_color is None else self.background_color
        return self.change_background_color(ColorInfo(source))

    def get_text_color(self, card_color):
        text = ColorInfo(self.initial_points[0])
        lightness = text.hsl[2]
        # Too little contrast with the card: darken by 0.5 (floored at 0).
        if abs(lightness - card_color.hsl[2]) < 0.6:
            text.hsl[2] = max(lightness - 0.5, 0.0)
        return text

    def get_button_color(self, card_color):
        delta = self.BUTTON_DELTA
        card_lightness = card_color.hsl[2]

        button = ColorInfo(self.initial_points[2])
        if not abs(button.hsl[2] - card_lightness) < delta:
            return button

        button = ColorInfo(self.initial_points[1])
        if abs(button.hsl[2] - card_lightness) < delta:
            # Still too close to the card: darken if possible, else lighten.
            lightness = button.hsl[2]
            button.hsl[2] = lightness - delta if lightness >= delta else lightness + delta
        return button


class DarkScheme(ColorScheme):
    """Scheme for dark images.

    The card comes from the background color or from sample point 0 (point 1
    if point 0's lightness is below 0.1). Text comes from sample point 3 and
    is lightened for contrast. The button is taken from sample point 1, or
    else 2.
    """

    def get_card_color(self):
        first, second = self.initial_points[0:2]
        if self.background_color is None:
            source = first if first.hsl[2] >= 0.1 else second
        else:
            source = self.background_color
        card = ColorInfo(source)
        # Clamp lightness to at least 0.1 before the background shift.
        if card.hsl[2] < 0.1:
            card.hsl[2] = 0.1
        return self.change_background_color(card)

    def get_text_color(self, card_color):
        text = ColorInfo(self.initial_points[3])
        lightness = text.hsl[2]
        # Too little contrast with the card: lighten by 0.5 (capped at 1).
        if abs(lightness - card_color.hsl[2]) < 0.6:
            text.hsl[2] = min(lightness + 0.5, 1.0)
        return text

    def get_button_color(self, card_color):
        second, third = self.initial_points[1:3]
        delta = self.BUTTON_DELTA
        card_lightness = card_color.hsl[2]

        button = ColorInfo(second)
        if not abs(button.hsl[2] - card_lightness) < delta:
            return button

        button = ColorInfo(third)
        if abs(button.hsl[2] - card_lightness) < delta:
            # Still too close to the card: lighten if possible, else darken.
            lightness = button.hsl[2]
            button.hsl[2] = lightness + delta if lightness <= 1.0 - delta else lightness - delta
        return button


def get_color_type(image, black_threshold=5, white_threshold=250, transparent_threshold=128):
    """Classify every pixel as 0, COLOR_BLACK, COLOR_WHITE or COLOR_TRANSPARENT.

    A pixel is black when all of its first three channels are below
    ``black_threshold`` and white when all are above ``white_threshold``.
    For images with more than 3 channels, a pixel is transparent when
    channel 3 is below ``transparent_threshold``. When several classes
    apply, transparent wins over white, and white over black.

    :param image: 3-D array ``(axis0, axis1, channels)``.
    :return: ``uint8`` array of shape ``image.shape[:2]``.
    """
    rows, cols, depth = image.shape
    rgb = image[:, :, :3]

    # Applied in order, so later layers overwrite earlier ones (priority).
    layers = [
        (np.all(rgb < black_threshold, axis=2), COLOR_BLACK),
        (np.all(rgb > white_threshold, axis=2), COLOR_WHITE),
    ]
    if depth > 3:
        layers.append((image[:, :, 3] < transparent_threshold, COLOR_TRANSPARENT))

    labels = np.zeros((rows, cols), dtype=np.uint8)
    for mask, flag in layers:
        labels[mask] = flag
    return labels


def abs_difference(a, b):
    """Return ``|a - b|`` element-wise without unsigned wrap-around.

    For unsigned integer data (e.g. ``uint8`` pixels) ``a - b`` wraps around
    when ``b > a``. Subtracting the element-wise minimum from the maximum
    never goes below zero, so the result is exact. Inputs broadcast like any
    NumPy binary operation, and NumPy scalars / 0-d arrays are supported. The
    previous masked-assignment version raised TypeError for scalar results.
    Signed inputs can still overflow if the true difference exceeds the
    dtype's range, exactly as plain subtraction would.

    :return: array (or NumPy scalar) with the broadcast shape of ``a`` and
        ``b``; the dtype follows NumPy's usual promotion rules.
    """
    return np.maximum(a, b) - np.minimum(a, b)


def has_rect_color(image, x0, y0, w, h, color, tolerance=3):
    """Return True if every pixel of a rectangle is within ``tolerance`` of ``color``.

    Coordinates index the array axes directly: ``x0``/``w`` select along
    axis 0 and ``y0``/``h`` along axis 1, i.e. ``image[x0:x0 + w, y0:y0 + h]``.
    For arrays built with ``np.asarray`` from a PIL image (as ``calculate``
    does), axis 0 is the row, so ``x`` here is the vertical coordinate.

    Every channel of ``color`` is compared, alpha included when present,
    using the overflow-safe :func:`abs_difference`. An empty rectangle
    counts as a match.

    :param tolerance: largest per-channel difference still treated as equal
        (default 3, the previously hard-coded value).
    :return: plain ``bool``.
    """
    # noinspection PyTypeChecker
    rect = image[x0:x0 + w, y0:y0 + h]
    # Logical ``not`` rather than bitwise ``~``: ``~`` only acts as negation
    # on NumPy booleans (for Python bools, ``~True == -2``, which is truthy).
    return not np.any(abs_difference(rect, color) > tolerance)
