import random, sys
from PIL import Image, ImageEnhance, ImageFilter, ImageDraw, ImageFont, ImageOps
from skimage import color
import numpy as np

class RandomTransform:
    """Random image augmentations for building "positive"/"negative" samples.

    ``mode == 'pos'`` selects mild parameter ranges (small rotations, slight
    crops, small colour shifts); any other value selects the stronger ranges
    coded in each method. Every transform reads ``self.img``, returns a new
    image without modifying ``self.img``, and logs its random parameters to
    stderr. ``run_sequence`` chains several transforms and restores the
    original image afterwards.
    """

    # Font used by add_text; it must be loadable by ImageFont.truetype on the
    # host. Override on the class or on an instance to use another font.
    font_name = "Avenir.ttc"

    # Candidate overlay strings for add_text.
    _OVERLAY_TEXTS = (
        "vk.com", "ok.ru", "youtube.com", "yandex.ru", "new.vk.com",
        "go.mail.ru", "google.ru", "e.mail.ru", "yandex.ua", "avito.ru",
        "facebook.com", "fotostrana.ru", "google.com.ua", "mail.ru", "instagram.com",
        "olx.ua", "ru.aliexpress.com", "yandex.by", "my.mail.ru", "love.mail.ru",
        "google.com", "google.by", "yandex.kz", "ru.wargaming.net", "news.mail.ru",
        "auto.ria.com", "ru.wikipedia.org", "kinogo.club", "kufar.by", "web.facebook.com",
        "ru.bongacams.com", "search.webalta.ru", "rst.ua", "mamba.ru", "mail.yandex.ru",
        "tabor.ru", "ask.fm", "m.vk.com", "auto.drom.ru", "market.yandex.ru",
        "news.yandex.ru", "google.kz", "2gis.ru", "spaces.ru", "node1.online.sberbank.ru",
        "24video.xxx", "kinogo.co", "otvet.mail.ru", "translate.yandex.ru", "auto.ru",
        "ex.ua", "accounts.google.com", "glurl.ru", "twitter.com", "drive2.ru",
        "heroeswm.ru", "wildberries.ru", "gismeteo.ru", "olx.kz", "kolesa.kz",
        "node2.online.sberbank.ru", "seasonvar.ru", "traffic-media.co", "xvideos.com", "twitch.tv",
        "mail.google.com", "am15.net", "vseigru.net", "ukr.net", "time-to-read.ru",
        "worldoftanks.ru", "steamcommunity.com", "ss.lv", "my-hit.org", "mferma.ru",
        "ficbook.net", "pornoonlain.tv", "fs.to", "sinoptik.ua", "imgsrc.ru",
        "yabs.yandex.ru", "kinopoisk.ru", "yadi.sk", "mxttrf.com", "zaycev.net",
        "rutracker.org", "gidonline.club", "sport.mail.ru", "cityadspix.com", "rozetka.com.ua",
        "pornhub.com", "rambler.ru", "baskino.club", "cloud.mail.ru", "wf.mail.ru",
        "999.md", "news.rambler.ru", "ebay.com", "glclck.ru", "parimatch.com",
    )

    # Candidate RGB text colours for add_text.
    _OVERLAY_TEXT_COLORS = (
        (200, 0, 0),
        (0, 200, 0),
        (0, 0, 200),
        (255, 50, 50),
        (50, 255, 50),
        (50, 50, 255),
        (50, 50, 50),
    )

    def __init__(self, img, mode):
        """Wrap *img* for augmentation in the given *mode*.

        Args:
            img: Source PIL image; its ``size`` provides ``x_size``/``y_size``.
            mode: ``'pos'`` for mild transforms, anything else for strong ones.
        """
        self.img = img
        self.x_size, self.y_size = img.size
        self.mode = mode
        self.min_size = 500          # smallest width/height accepted by valid_size
        self.max_size_ratio = 1.75   # largest aspect ratio accepted by valid_size

    @staticmethod
    def _log(*parts):
        """Write *parts* to stderr, space-separated and newline-terminated.

        Produces the same text as Python 2's ``print >>sys.stderr, a, b`` while
        also running under Python 3 (where that statement raises TypeError).
        """
        sys.stderr.write(' '.join(str(part) for part in parts) + '\n')

    @staticmethod
    def _text_extent(font, text):
        """Return the (width, height) extent of *text* measured from the draw origin.

        Uses ``getbbox`` when available (``getsize`` was removed in Pillow 10);
        its right/bottom edges equal what the legacy ``getsize`` returned.
        Falls back to ``getsize`` on older PIL/Pillow versions.
        """
        if hasattr(font, 'getbbox'):
            _, _, right, bottom = font.getbbox(text)
            return right, bottom
        return font.getsize(text)

    def valid_size(self):
        """Return True if the image is large enough and not too elongated.

        Logs the reason and returns False when either side is below
        ``min_size`` or either side-length ratio exceeds ``max_size_ratio``.
        """
        if self.x_size < self.min_size or self.y_size < self.min_size:
            self._log("Image is smaller than", self.min_size)
            return False
        if float(self.x_size) / self.y_size > self.max_size_ratio or float(self.y_size) / self.x_size > self.max_size_ratio:
            self._log("Image proportion is bigger than", self.max_size_ratio)
            return False
        return True

    def update(self, img):
        """Replace the working image and refresh the cached dimensions."""
        self.img = img
        self.x_size, self.y_size = img.size

    def mirror(self):
        """Flip left-right in 'pos' mode, top-bottom otherwise."""
        if self.mode == 'pos':
            flip = Image.FLIP_LEFT_RIGHT
        else:
            flip = Image.FLIP_TOP_BOTTOM

        self._log("Mirror transform", flip)
        return self.img.transpose(flip)

    def borderline(self):
        """Add a black border 2-3% of the shorter side wide on every edge."""
        lower_bound = 0.02
        thresh_pos = 0.03
        width = int(min(self.x_size, self.y_size) * (random.random() * (thresh_pos - lower_bound) + lower_bound))

        self._log("Adding border", width)
        return ImageOps.expand(self.img, border=width, fill='black')

    def rotate(self):
        """Rotate by a random signed angle, then trim 3% from each edge.

        The magnitude is in [2.5, 3) degrees in 'pos' mode and [30, 50)
        degrees otherwise; the sign is chosen at random.
        """
        lower_bound = 2.5
        thresh_pos = 3
        thresh_neg = 30
        upper_bound = 50

        sign = random.randrange(-1, 2, 2)  # -1 or +1
        if self.mode == 'pos':
            angle = sign * (random.random() * (thresh_pos - lower_bound) + lower_bound)
        else:
            angle = sign * (random.random() * (upper_bound - thresh_neg) + thresh_neg)

        rotated = self.img.rotate(angle, resample=Image.BICUBIC)
        # NOTE(review): the trim margin is thresh_pos percent in both modes,
        # so large negative-mode rotations keep their filled corners. Confirm
        # this is intended.
        xa = int(self.x_size * thresh_pos / 100.0)
        xb = self.x_size - xa
        ya = int(self.y_size * thresh_pos / 100.0)
        yb = self.y_size - ya
        cropped = rotated.crop((xa, ya, xb, yb))

        self._log("Rotate transform", int(angle))
        return cropped

    def shrink(self):
        """Rescale one randomly chosen axis by a factor in [0.8, 1.2)."""
        thresh_pos = 0.2
        vertical = bool(random.randrange(0, 2))
        scale = random.random() * 2 * thresh_pos + 1 - thresh_pos

        if vertical:
            dst_size = (self.x_size, int(self.y_size * scale))
        else:
            dst_size = (int(self.x_size * scale), self.y_size)

        self._log("Shrink transform vertical flag", vertical, "scale", scale)
        return self.img.resize(dst_size, resample=Image.BICUBIC)

    def crop(self):
        """Cut out a randomly placed sub-rectangle.

        In 'pos' mode each side keeps (0.9, 1.0] of its length; otherwise each
        side keeps (0.4, 0.55]. The two axes are drawn independently.
        """
        thresh_pos = 0.1
        thresh_neg = 0.45
        upper_bound = 0.6

        if self.mode == 'pos':
            sx = int(self.x_size * (1 - random.random() * thresh_pos))
            sy = int(self.y_size * (1 - random.random() * thresh_pos))
        else:
            sx = int(self.x_size * (1 - random.random() * (upper_bound - thresh_neg) - thresh_neg))
            sy = int(self.y_size * (1 - random.random() * (upper_bound - thresh_neg) - thresh_neg))

        x = int(random.random() * (self.x_size - sx))
        y = int(random.random() * (self.y_size - sy))

        self._log("Crop transform", float(sx) / self.x_size, float(sy) / self.y_size, float(x) / self.x_size, float(y) / self.y_size)
        return self.img.crop((x, y, x + sx, y + sy))

    def resize(self):
        """Downscale both axes by the same random factor in [3, 4)."""
        thresh_pos = 3
        upper_bound = 4
        factor = random.random() * (upper_bound - thresh_pos) + thresh_pos
        width, height = self.img.size

        self._log("Resize transform", factor)
        return self.img.resize((int(width / factor), int(height / factor)), resample=Image.BICUBIC)

    def hue(self):
        """Scale the HSV hue channel by a random factor around 1.

        The factor comes from rand_float_around_one with d in [0.1, 0.2) in
        'pos' mode and [0.6, 0.7) otherwise. Returns an RGB image.
        """
        if self.mode == 'pos':
            lower = 0.1
            upper = 0.2
        else:
            lower = 0.6
            upper = 0.7

        coeff = rand_float_around_one(lower, upper)

        hsv_img_arr = np.array(self.img.convert('HSV'))
        # Per-channel factors broadcast over the trailing channel axis, so
        # only hue (channel 0) is scaled.
        scaled = np.multiply(hsv_img_arr, np.array([coeff, 1.0, 1.0]))
        # NOTE(review): hue is cyclic, but values pushed past 255 are clamped
        # rather than wrapped. Confirm this is intended.
        enhanced_img_arr = np.uint8(np.minimum(np.floor(scaled), 255))

        self._log("Hue transform", coeff)
        return Image.fromarray(enhanced_img_arr, 'HSV').convert('RGB')

    def color(self):
        """Scale R and B by a random factor ``c`` and G by ``1 / c``.

        ``c`` comes from rand_float_around_one with d in [0.03, 0.04) in
        'pos' mode and [0.3, 0.4) otherwise. Results are clamped at 255.
        """
        if self.mode == 'pos':
            lower = 0.03
            upper = 0.04
        else:
            lower = 0.3
            upper = 0.4

        coeff = rand_float_around_one(lower, upper)

        # Assumes a 3-channel image (one factor per channel), as the original
        # per-pixel coefficient matrix did.
        rgb_img_arr = np.array(self.img)
        scaled = np.multiply(rgb_img_arr, np.array([coeff, 1.0 / coeff, coeff]))
        enhanced_img_arr = np.uint8(np.minimum(np.floor(scaled), 255))

        self._log("color transform", coeff)
        return Image.fromarray(enhanced_img_arr)

    def add_text(self):
        """Draw a random domain name in a random colour onto a copy of the image.

        The font is scaled so the text spans about 3/4 of the image width.
        The text starts 1/8 of the width from the left edge. Its top edge is
        placed at a random height between 1/10 of the image height and the
        point where the text's bottom would reach 9/10 of the height.
        """
        start_font_pt = 40

        text = random.choice(self._OVERLAY_TEXTS)
        text_color = random.choice(self._OVERLAY_TEXT_COLORS)

        new_img = self.img.copy()
        draw = ImageDraw.Draw(new_img)

        # Measure at a reference size, then rescale the font to the target width.
        probe_font = ImageFont.truetype(self.font_name, start_font_pt)
        probe_width = self._text_extent(probe_font, text)[0]
        font = ImageFont.truetype(self.font_name, int(float(start_font_pt) * self.x_size * 3 / 4 / probe_width))

        # Floor division keeps the coordinates integral, matching Python 2's
        # integer `/`, under Python 3 as well.
        lower_bound = self.y_size // 10
        upper_bound = self.y_size // 10 * 9 - self._text_extent(font, text)[1]
        text_position = (self.x_size // 8, int(random.random() * (upper_bound - lower_bound) + lower_bound))

        draw.text(text_position, text, text_color, font=font)

        self._log("Add text transform", text, text_position, text_color)
        return new_img

    def generic_enhance_wrapper(self, method, lower, upper):
        """Return a zero-argument transform applying ``method(img).enhance(p)``.

        ``p`` is drawn with rand_float_around_one(lower, upper) on each call,
        and ``self.img`` is read at call time, not at creation time.
        """
        def wrapper():
            param = rand_float_around_one(lower, upper)
            self._log("Generic enhance", method, param)
            return method(self.img).enhance(param)
        return wrapper

    def generic_filter_wrapper(self, method, lower, upper, odd=True):
        """Return a zero-argument transform applying ``img.filter(method(p))``.

        ``p`` is a random odd integer in [lower, upper] when *odd* is true,
        otherwise a float in [lower, upper). It is drawn on each call, and
        ``self.img`` is read at call time.
        """
        def wrapper():
            if odd:
                param = rand_odd(lower, upper)
            else:
                param = random.random() * (upper - lower) + lower
            self._log("Generic filter", method, param)
            return self.img.filter(method(param))
        return wrapper

    def run_sequence(self, n=2):
        """Apply up to *n* randomly ordered transform groups and return the result.

        The groups are shuffled and the first *n* are used. From each one a
        single member is chosen at random and applied to the output of the
        previous step. ``self.img`` is restored to its original value
        afterwards, even when a transform raises.
        """
        if self.mode == 'pos':
            transformations = [
                [self.mirror],
                [self.shrink,
                 self.rotate,
                 self.crop],
                [self.color,
                 self.hue],
                [self.resize],
                [self.borderline],
                [self.generic_enhance_wrapper(ImageEnhance.Brightness, 0.05, 0.15)],
                [self.generic_enhance_wrapper(ImageEnhance.Contrast, 0.05, 0.15)],
                [self.generic_filter_wrapper(ImageFilter.GaussianBlur, 1, 1)],
            ]
        else:
            transformations = [
                [self.crop],
                # Negative-mode candidates currently disabled:
                #[self.rotate,
                #[self.color],
                #[self.generic_enhance_wrapper(ImageEnhance.Brightness, 0.9, 1.1)],
                #[self.generic_enhance_wrapper(ImageEnhance.Contrast, 1.1, 1.3)],
                #[self.generic_filter_wrapper(ImageFilter.MedianFilter, 5, 5),
                #self.generic_filter_wrapper(ImageFilter.UnsharpMask, 7, 7),
                #self.generic_filter_wrapper(ImageFilter.GaussianBlur, 3, 3)],
                #[self.generic_filter_wrapper(ImageFilter.MinFilter, 5, 5),
                #self.generic_filter_wrapper(ImageFilter.MaxFilter, 5, 5)],
                #[self.add_text]
            ]

        image_backup = self.img

        random.shuffle(transformations)
        try:
            for t_group in transformations[:n]:
                self.update(random.choice(t_group)())
            return self.img
        finally:
            # Restore the caller's image even if a transform fails mid-sequence.
            self.update(image_backup)


def rand_odd(lower, upper):
    """Return a uniformly chosen odd integer from the closed range [lower, upper].

    Raises IndexError (from random.choice) when the range contains no odd value.
    """
    first_odd = lower | 1  # lower itself when odd, otherwise the next integer up
    candidates = list(range(first_odd, upper + 1, 2))
    return random.choice(candidates)


def rand_float_around_one(lower, upper):
    """Return a random multiplicative factor close to 1.

    A magnitude ``1 + d`` is drawn with ``d`` uniform in [lower, upper). With
    equal probability the function returns either that magnitude or its
    reciprocal, so the factor enlarges or shrinks by a comparable ratio.
    """
    magnitude = random.random() * (upper - lower) + lower + 1.0
    take_reciprocal = random.randrange(0, 2) != 0
    if take_reciprocal:
        return 1.0 / magnitude
    return magnitude


def concat_images(images, border=0):
    """Place *images* side by side, left to right, on a new black RGB canvas.

    Images are top-aligned, and the canvas is as tall as the tallest one. When
    *border* is non-zero, a black strip that many pixels wide is placed before
    the first image, between neighbours and after the last image.
    """
    content_width = sum(img.size[0] for img in images)
    height = max(img.size[1] for img in images)

    canvas = Image.new("RGB", (content_width + border * (len(images) + 1), height))
    separator = Image.new('RGB', (border, height))

    offset = 0
    for img in images:
        if border != 0:
            canvas.paste(separator, (offset, 0))
            offset += separator.size[0]
        canvas.paste(img, (offset, 0))
        offset += img.size[0]

    # Closing strip after the last image.
    if border != 0:
        canvas.paste(separator, (offset, 0))

    return canvas
