from functools import partial
import math
from datetime import datetime

class Rule(object):
    """A named predicate over a single factor of an image dict.

    Rules compare equal and hash by ``name`` alone, so two rules with the same
    name are interchangeable as set members or dict keys.
    """

    def __init__(self, factor, name, func):
        # factor: key looked up in the image dict
        # name: human-readable label; also the rule's identity for ==/hash
        # func: predicate applied to img[factor]
        self.factor = factor
        self.name = name
        self.func = func

    def __call__(self, img):
        """Apply the predicate to ``img[self.factor]``; False if the factor is absent."""
        if self.factor not in img:
            return False
        return self.func(img[self.factor])

    def __eq__(self, other):
        # Comparing against a non-Rule (None, a string, ...) used to raise
        # AttributeError on ``other.name``; let Python fall back instead.
        if not isinstance(other, Rule):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name


###################################################################################
#### functions to be partially applied and used as factor splitting rules #########

def get_factor_in_arr_func(factor, values):
    """Membership predicate: is the factor value one of ``values``?

    Intended to be bound with ``functools.partial(..., values=...)``.
    """
    is_member = factor in values
    return is_member


def get_factor_greater_val_func(factor, val):
    """Return True if ``factor`` converted to float is strictly greater than ``val``.

    Intended to be bound with ``functools.partial(..., val=threshold)``.
    Values that ``float()`` cannot convert (None, non-numeric strings,
    integers too large for a float) never satisfy the rule.
    """
    try:
        numeric = float(factor)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures mean "no match"; anything else (including
        # KeyboardInterrupt) must propagate instead of being swallowed.
        return False
    return numeric > val


###########################################################
###### data preprocessing and rules generation ############

def _add_calculated_factors_to_image(img, custom_touch_viewport):
    """Compute the derived factors of a single image dict, in place."""
    width = img.get('width')
    height = img.get('height')

    # Images without a 'query_device' are treated as desktop results.
    if img.get('query_device', 'DESKTOP') == 'DESKTOP':
        img['viewport'] = desktop_viewport_scale(width, height)
    else:
        img['viewport'] = touch_viewport_scale(width, height, custom_touch_viewport)

    img['square'] = square_tanh_scale(width, height, 2.0)

    # Currently disabled factors:
    #utility_approx = utility_approx_factor(img)
    #if utility_approx is not None:
    #    img['utility_approx'] = utility_approx
    #img['min_defect'] = min_defect_factor(img)
    #img['corsa-xr'] = corsa_xr_metric(img)
    #img['corsa-stove-dbd-precise'] = corsa_stove_dbd_precise_metric(img)
    #img['corsa-stove-dbd-reduced'] = corsa_stove_dbd_reduced_metric(img)
    #img['corsa-stove-dbd-noutil-reduced'] = corsa_stove_dbd_noutil_reduced_metric(img)
    img['light_proxima'] = light_proxima_metric(img)


def add_calculated_factors_to_images(images, flat=True, custom_touch_viewport=None):
    """Add the 'viewport', 'square' and 'light_proxima' factors to every image.

    :param images: if ``flat``, an iterable of image dicts; otherwise a mapping
        whose values (looked up via ``images[query]``) are iterables of image dicts.
    :param flat: selects between the two ``images`` layouts above.
    :param custom_touch_viewport: optional multiplier for the touch screen
        height, forwarded to ``touch_viewport_scale``.

    The image dicts are modified in place; nothing is returned.
    """
    if flat:
        image_groups = [images]
    else:
        image_groups = (images[query] for query in images)

    for group in image_groups:
        for img in group:
            _add_calculated_factors_to_image(img, custom_touch_viewport)


def generate_split_rules():
    """Build the candidate split rules.

    :returns: a pair ``(rules, relevance_rules)``. ``rules`` holds threshold
        rules named ``'<name> > <delim>'`` over numeric factors.
        ``relevance_rules`` holds ordinal rules over the ``'relevance'`` label.
    """
    rules = []

    def category_factor_rules(factor, name, values):
        # ``values`` must be listed in ascending order: the rule for threshold
        # values[i] matches any label strictly above it, i.e. in values[i+1:].
        factor_rules = []
        for i, threshold_value in enumerate(values[:-1]):
            rule_name = '{} > {}'.format(name, threshold_value)
            factor_rules.append(Rule(factor, rule_name, partial(get_factor_in_arr_func, values=values[i + 1:])))
        return factor_rules

    def float_factor_rules(factor, name, bucket_delims):
        # One "factor > delim" rule per delimiter.
        return [Rule(factor, '{} > {}'.format(name, delim), partial(get_factor_greater_val_func, val=delim))
                for delim in bucket_delims]

    #rules.extend(category_factor_rules('page_binary_relevance', 'page_rel', ['IRRELEVANT', 'RELEVANT']))

    rules.extend(float_factor_rules('visual_quality', 'visual_quality', [x / 10.0 for x in range(1, 10)]))
    rules.extend(float_factor_rules('viewport', 'viewport', [0.256, 0.379, 0.472, 0.53, 0.62, 0.702, 0.752, 0.795, 0.85]))
    rules.extend(float_factor_rules('square', 'square', [0.349, 0.496, 0.623, 0.744, 0.856, 0.907, 0.951, 0.976, 0.9995]))
    rules.extend(float_factor_rules('aestetics', 'aestetics', [x / 10.0 for x in range(1, 10)]))
    #rules.extend(float_factor_rules('light_proxima', 'light_proxima', [0.05, 0.206, 0.42, 0.5, 0.505, 0.62, 0.759, 0.921, 0.9965]))
    # Delimiters are ascending; 0.7499/0.75 mirrors the 0.2499/0.25 pair
    # (this was previously mistyped as 7499).
    rules.extend(float_factor_rules('light_proxima', 'light_proxima', [0.2499, 0.25, 0.47, 0.63, 0.7499, 0.75, 0.957, 0.99]))
    rules.extend(float_factor_rules('dbd_signal_v3', 'dbd_signal', [x / 10.0 for x in range(1, 10)]))

    return rules, category_factor_rules('relevance', 'relevance', ['IRRELEVANT', 'RELEVANT_MINUS', 'RELEVANT_PLUS'])


##################################################
############### factor calculations ##############

# Numeric grade for each relevance label; read by utility_approx_factor and
# corsa_dt_metric.
REL_MAP = dict(
    RELEVANT_PLUS=1.0,
    RELEVANT_MINUS=0.75,
    IRRELEVANT=0,
)


def desktop_viewport_scale(iw, ih):
    """Weighted fraction of a desktop results area covered by an iw x ih image.

    For each listed screen, the usable area is the screen minus fixed margins
    (2*45 + 340 px horizontally, 20 + 26 + 101 px vertically). The image is
    scaled down (never up) to fit that area. The covered fraction is then
    averaged over screens using their normalized weights (presumably screen
    usage shares). Missing or zero dimensions give 0.0.
    """
    if not iw or not ih:
        return 0.0

    # (screen width, screen height, weight)
    screens = [
        (1366, 768,  0.292),
        (1920, 1080, 0.188),
        (1280, 1024, 0.102),
        (1600, 900,  0.071),
        (1440, 900,  0.040),
        (1280, 800,  0.033),
        (1280, 800,  0.032),
        (1024, 768,  0.030),
        (1536, 864,  0.030),
        (1680, 1050, 0.028),
        (1280, 720,  0.019),
        (1280, 720,  0.019),
        (1360, 768,  0.016),
    ]

    normalizer = 1.0 / sum([weight for _, _, weight in screens])
    image_area = iw * ih

    # Explicit running sum keeps the float accumulation order unchanged.
    covered = 0
    for screen_w, screen_h, weight in screens:
        usable_w = screen_w - 45*2 - 340
        usable_h = screen_h - 20 - 26 - 101
        fit = min(usable_w / float(iw), usable_h / float(ih), 1)
        area_ratio = image_area / float(usable_w * usable_h)
        covered += weight * normalizer * (area_ratio * fit ** 2)
    return covered


def touch_viewport_scale(iw, ih, custom_viewport):
    """Fraction of a 750x1334 touch screen covered by an iw x ih image.

    The image is scaled down (never up) to fit the screen. ``custom_viewport``,
    when not None, multiplies the screen height. Missing or zero dimensions
    give 0.0.
    """
    if not iw or not ih:
        return 0.0

    screen_w, screen_h = 750, 1334
    if custom_viewport is not None:
        screen_h *= custom_viewport

    area_ratio = (iw * ih) / float(screen_w * screen_h)
    fit = min(screen_w / float(iw), screen_h / float(ih), 1)
    return area_ratio * fit ** 2


def square_scale(iw, ih, growth, shift):
    """Logistic score of the image area in megapixels.

    Returns ``1 / (1 + exp(-growth * area / 1e6 + shift))``, or 0.0 when a
    dimension is missing or zero.
    """
    if not (iw and ih):
        return 0.0
    exponent = growth * -(iw * ih) / 1000000.0 + shift
    return 1.0 / (1 + math.exp(exponent))


def square_tanh_scale(iw, ih, growth):
    """Return ``tanh(growth * area / 1e6)``, or 0.0 for missing/zero dimensions."""
    if not (iw and ih):
        return 0.0
    area = iw * ih
    return math.tanh(growth * area / 1000000.0)


def utility_approx_factor(img):
    """Linear approximation of image utility from quality, size and relevance.

    Returns None unless all of 'images_vq3_v2_avg', 'aestetics_mean5',
    'image_width', 'image_height' and 'relevance' are present. Otherwise
    returns 0.4*vq + 0.16*aesthetics + 0.1*square_tanh + 0.07*desktop_viewport
    + 0.27*REL_MAP[relevance]. An unknown relevance label raises KeyError.
    """
    required = ('images_vq3_v2_avg', 'aestetics_mean5', 'image_height', 'image_width', 'relevance')
    if any(key not in img for key in required):
        return None

    width = img['image_width']
    height = img['image_height']
    terms = (
        0.4 * img['images_vq3_v2_avg'],
        0.16 * img['aestetics_mean5'],
        0.1 * square_tanh_scale(width, height, 2),
        0.07 * desktop_viewport_scale(width, height),
        0.27 * REL_MAP[img['relevance']],
    )

    # Sequential accumulation, in the same order as the terms above.
    score = 0.0
    for term in terms:
        score += term
    return score


def diff_month(d1, d2):
    """Calendar-month difference from ``d2`` to ``d1`` (days ignored), floored at 0."""
    months = 12 * (d1.year - d2.year) + (d1.month - d2.month)
    return months if months > 0 else 0


def calc_timestamp_factor(image_create_time, current_date, steps, need_fresh=True):
    """Return a freshness score for an image based on its age in months.

    :param image_create_time: creation time as a Unix timestamp. Falsy, NaN or
        values below 900000000 (mid-1998) score 0.0.
    :param current_date: ``datetime`` the age is measured against.
    :param steps: mapping ``{max_age_in_months: score}``. The score of the
        smallest key the image's age does not exceed is returned.
    :param need_fresh: when False, freshness is not scored and 0.0 is returned.
    :returns: the matching score from ``steps``, or 0.0 if the age exceeds
        every key.
    """
    # Check the switch first so a disabled freshness factor never inspects
    # (and can never fail on) the timestamp value.
    if not need_fresh or not image_create_time or math.isnan(image_create_time):
        return 0.0

    image_create_time = int(image_create_time)
    if image_create_time < 900000000:
        return 0.0

    # NOTE(review): fromtimestamp() uses the local timezone — confirm intended.
    age_months = diff_month(current_date, datetime.fromtimestamp(image_create_time))
    for max_age in sorted(steps):
        if age_months <= max_age:
            return steps[max_age]
    return 0.0


def grayscale_avg_scale(mark):
    """Bucket an average gray-deviation mark into a score between 0.0 and 1.0.

    Unjudged marks ("NOT_JUDGED", "-1", "", -1) score 0.0. Any other mark is
    converted with ``float`` and asserted non-negative. Larger deviations
    score higher: >1000 -> 1.0, >600 -> 0.8, >400 -> 0.6, >250 -> 0.4,
    >150 -> 0.2, >100 -> 0.1, otherwise 0.0.
    """
    if mark in ("NOT_JUDGED", "-1", "", -1):
        return 0.0

    assert float(mark) >= 0, "Grayscale must be a positive float number: {}".format(mark)

    deviation = float(mark)
    # (exclusive lower bound, score), checked from the highest bound down.
    buckets = ((1000, 1.0), (600, 0.8), (400, 0.6), (250, 0.4), (150, 0.2), (100, 0.1))
    for lower_bound, score in buckets:
        if deviation > lower_bound:
            return score
    return 0.0


def light_proxima_metric(img):
    """Average of the image's 'kernel' score and a 0/1 page-relevance score.

    A missing or falsy 'kernel' counts as 0.0. A missing 'page_relevance'
    counts as 'IRRELEVANT'. Labels other than RELEVANT, IRRELEVANT and _404
    raise KeyError.
    """
    kernel = img.get('kernel', 0.0) or 0.0
    page_score_by_label = {
        "RELEVANT": 1.0,
        "IRRELEVANT": 0.0,
        "_404": 0.0,
    }
    page_score = page_score_by_label[img.get('page_relevance', 'IRRELEVANT')]
    return (kernel + page_score) / 2.0


def corsa_xr_metric(img, boost_weight=0.3):
    """Weighted linear "corsa-xr" quality score of an image (weights sum to 1).

    Returns 0 for images whose 'relevance' is 'IRRELEVANT'. ``boost_weight``
    is accepted for interface compatibility but is currently unused.
    """
    if img.get('relevance') == 'IRRELEVANT':
        return 0

    weights = {
        'util': 0.5,
        'vp': 0.11421,
        'aes': 0.09114,
        'page_rel': 0.04307,
        'size': 0.08142,
        'color': 0.02186,
        'kernel': 0.03267,
        'vq': 0.11563,
    }

    terms = [
        img.get('utility_v2_avg', 0) * weights['util'],
        img.get('viewport', 0) * weights['vp'],
        img.get('aestetics_mean5', 0) * weights['aes'],
    ]
    if img.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        terms.append(weights['page_rel'])
    terms += [
        img.get('square_tanh_2_0', 0) * weights['size'],
        grayscale_avg_scale(img.get('avatars_avg_gray_deviation', 0)) * weights['color'],
        (img.get('biz_kernel_quantile', 0) or 0) * weights['kernel'],
        img.get('images_vq3_v2_avg', 0) * weights['vq'],
    ]

    # Disabled variants: multiply by img['min_defect'], or blend as
    # (REL_MAP[relevance] + boost_weight * score) / (1.0 + boost_weight).
    score = 0
    for term in terms:
        score += term
    return score


def corsa_stove_dbd_noutil_reduced_metric(img, boost_weight=0.3):
    """Weighted linear score without the utility signal, using 'dbd_signal_const_rank'.

    The weights sum to 1; 'util', 'color' and 'kernel' have weight 0.0.
    Returns 0 for images whose 'relevance' is 'IRRELEVANT'. ``boost_weight``
    is accepted for interface compatibility but is currently unused.
    """
    if img.get('relevance') == 'IRRELEVANT':
        return 0

    weights = {
        'util': 0.0,
        'vp': 0.09,
        'aes': 0.16,
        'page_rel': 0.02,
        'size': 0.09,
        'color': 0.0,
        'kernel': 0.0,
        'vq': 0.13,
        'dbd_signal_const_rank': 0.51
    }

    terms = [
        img.get('utility_v2_avg', 0) * weights['util'],
        img.get('viewport', 0) * weights['vp'],
        img.get('aestetics_mean5', 0) * weights['aes'],
    ]
    if img.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        terms.append(weights['page_rel'])
    terms += [
        img.get('square_tanh_2_0', 0) * weights['size'],
        grayscale_avg_scale(img.get('avatars_avg_gray_deviation', 0)) * weights['color'],
        (img.get('biz_kernel_quantile', 0) or 0) * weights['kernel'],
        img.get('images_vq3_v2_avg', 0) * weights['vq'],
        img.get('dbd_signal_const_rank', 0) * weights['dbd_signal_const_rank'],
    ]

    # Disabled variants: multiply by img['min_defect'], or blend as
    # (REL_MAP[relevance] + boost_weight * score) / (1.0 + boost_weight).
    score = 0
    for term in terms:
        score += term
    return score


def corsa_stove_dbd_precise_metric(img, boost_weight=0.3):
    """Weighted linear score with each term divided by (1.0 - 0.40654).

    The listed weights sum to about 0.5935, roughly 1 - 0.40654, so the
    division brings the total weight close to 1 (presumably a removed
    component with weight 0.40654 — TODO confirm). Returns 0 for images whose
    'relevance' is 'IRRELEVANT'. ``boost_weight`` is accepted for interface
    compatibility but is currently unused.
    """
    if img.get('relevance') == 'IRRELEVANT':
        return 0

    weights = {
        'util': 0.21978,
        'vp': 0.05575,
        'aes': 0.00042,
        'page_rel': 0.00174,
        'size': 0.00262,
        'color': 0.01937,
        'kernel': 0.00015,
        'vq': 0.10258,
        'dbd_signal_const_rank': 0.19107
    }
    norm = 1.0 - 0.40654

    terms = [
        img.get('utility_v2_avg', 0) * weights['util'] / norm,
        img.get('viewport', 0) * weights['vp'] / norm,
        img.get('aestetics_mean5', 0) * weights['aes'] / norm,
    ]
    if img.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        terms.append(weights['page_rel'] / norm)
    terms += [
        img.get('square_tanh_2_0', 0) * weights['size'] / norm,
        grayscale_avg_scale(img.get('avatars_avg_gray_deviation', 0)) * weights['color'] / norm,
        (img.get('biz_kernel_quantile', 0) or 0) * weights['kernel'] / norm,
        img.get('images_vq3_v2_avg', 0) * weights['vq'] / norm,
        img.get('dbd_signal_const_rank', 0) * weights['dbd_signal_const_rank'] / norm,
    ]

    # Disabled variants: multiply by img['min_defect'], or blend as
    # (REL_MAP[relevance] + boost_weight * score) / (1.0 + boost_weight).
    score = 0
    for term in terms:
        score += term
    return score


def corsa_stove_dbd_reduced_metric(img, boost_weight=0.3):
    """Reduced weighted linear score using utility and 'dbd_signal_const_rank'.

    The weights sum to 1; 'aes', 'page_rel' and 'kernel' have weight 0.
    Returns 0 for images whose 'relevance' is 'IRRELEVANT'. ``boost_weight``
    is accepted for interface compatibility but is currently unused.
    """
    if img.get('relevance') == 'IRRELEVANT':
        return 0

    weights = {
        'util': 0.35,
        'vp': 0.06,
        'aes': 0,
        'page_rel': 0,
        'size': 0.06,
        'color': 0.03,
        'kernel': 0,
        'vq': 0.18,
        'dbd_signal_const_rank': 0.32
    }

    terms = [
        img.get('utility_v2_avg', 0) * weights['util'],
        img.get('viewport', 0) * weights['vp'],
        img.get('aestetics_mean5', 0) * weights['aes'],
    ]
    if img.get('page_binary_relevance', 'NO_MARK') == 'RELEVANT':
        terms.append(weights['page_rel'])
    terms += [
        img.get('square_tanh_2_0', 0) * weights['size'],
        grayscale_avg_scale(img.get('avatars_avg_gray_deviation', 0)) * weights['color'],
        (img.get('biz_kernel_quantile', 0) or 0) * weights['kernel'],
        img.get('images_vq3_v2_avg', 0) * weights['vq'],
        img.get('dbd_signal_const_rank', 0) * weights['dbd_signal_const_rank'],
    ]

    # Disabled variants: multiply by img['min_defect'], or blend as
    # (REL_MAP[relevance] + boost_weight * score) / (1.0 + boost_weight).
    score = 0
    for term in terms:
        score += term
    return score


def min_defect_factor(img):
    """Return the weakest normalized quality signal of an image.

    - 0.0 if 'relevance' is 'IRRELEVANT' or missing;
    - 1.0 if 'wideness' is 'wideness_3' (also assumed when 'wideness' is missing);
    - otherwise the minimum of 1.0 and ``lin_transform(img[signal], lo, hi)``
      over the signals below. Returns 0.0 as soon as a signal is missing.
    """
    bounds = {
            'images_vq3_v2_avg': (0.5, 0.8),
            'aestetics_mean5': (0.5, 0.8),
            'utility_v2_avg': (0.6, 0.8),
            'square_tanh_2_0': (0.7, 0.9),
            'viewport': (0.6, 0.7)
            }

    if img.get('relevance', 'IRRELEVANT') == 'IRRELEVANT':
        return 0.0
    if img.get('wideness', 'wideness_3') == 'wideness_3':
        return 1.0

    worst = 1.0
    for signal, (low, high) in bounds.items():
        if signal not in img:
            return 0.0
        worst = min(worst, lin_transform(img[signal], low, high))
    return worst


def corsa_dt_metric(img, boost_weight=0.3):
    """Blend the relevance grade with the image's 'dt_score'.

    Returns 0 for 'IRRELEVANT' images. Otherwise returns
    ``(REL_MAP.get(relevance, 0) + boost_weight * dt_score) / (1 + boost_weight)``,
    where a missing 'dt_score' counts as 0.0.
    """
    relevance = img.get('relevance')
    if relevance == 'IRRELEVANT':
        return 0

    dt_score = img.get('dt_score', 0.0)
    base = REL_MAP.get(relevance, 0)
    return (base + boost_weight * dt_score) / (1.0 + boost_weight)


def lin_transform(num, min_val, max_val):
    """Linearly map ``num`` from [min_val, max_val] onto [0.0, 1.0], clamping.

    Values at or below ``min_val`` give 0.0, values at or above ``max_val``
    give 1.0. When ``min_val == max_val`` every input hits one of the clamps,
    so no division by zero can occur.
    """
    # The clamped branches must return the normalized bounds (0.0 / 1.0), not
    # the raw min_val / max_val, or the mapping jumps at both ends.
    if num <= min_val:
        return 0.0
    if num >= max_val:
        return 1.0
    # float() keeps the division fractional on Python 2 as well.
    return max(0.0, min(1.0, (num - min_val) / float(max_val - min_val)))
