import random
import logging
import math
import itertools
import cStringIO

import numpy as np
from PIL import Image, ImageDraw, ImageFont

from captcha.generation.image_generator.noise import whitenoise


class FailureToFit(Exception):
    """Raised when the rendered characters cannot be laid out within the image width.

    Derives from Exception rather than BaseException: BaseException is meant for
    interpreter-level signals such as SystemExit and KeyboardInterrupt, which an
    ordinary ``except Exception`` handler is supposed to let through.
    """


def get_font(context, char, scale=1):
    """Load a randomly chosen TrueType font at a randomly chosen, scaled size.

    The point size is picked from ``config['character_rendering']['font_sizes']``
    and multiplied by ``scale`` (truncated to int). The font file name is then
    picked from ``config['character_rendering']['fonts']``, and its raw bytes
    are looked up in ``context['resources']``. ``char`` is currently unused.
    """
    # TODO: caching
    rendering = context['config']['character_rendering']
    # The size is drawn before the font file; keep this order so seeded runs
    # produce the same choices.
    point_size = int(random.choice(rendering['font_sizes']) * scale)
    font_bytes = context['resources'].get(random.choice(rendering['fonts']))
    return ImageFont.truetype(cStringIO.StringIO(font_bytes), point_size)


def text_stroke(context, draw, coords, text, font=None, fillcolor='white', strokecolor='black', scale=1):
    """Draw ``text`` at ``coords`` surrounded by a square outline.

    The outline is made by drawing the text in ``strokecolor`` at every integer
    offset (dx, dy) with |dx|, |dy| <= radius, except (0, 0). The radius is
    ``config['character_rendering']['stroke_width']`` (default 1) times
    ``scale``, rounded up. The text is then drawn once more at the exact
    position in ``fillcolor``.
    """
    configured = context['config']['character_rendering'].get('stroke_width', 1)
    radius = int(math.ceil(configured * scale))
    origin_x, origin_y = coords
    offsets = range(-radius, radius + 1)
    # product() yields dx in the outer position and dy in the inner one,
    # matching a nested "for dx: for dy:" loop.
    for dx, dy in itertools.product(offsets, repeat=2):
        if dx or dy:
            draw.text((origin_x + dx, origin_y + dy), text, font=font, fill=strokecolor)
    draw.text((origin_x, origin_y), text, font=font, fill=fillcolor)


def get_char_size(context, font, char, scale=1):
    """Return the canvas size needed for ``char`` in ``font`` plus a stroke outline.

    The stroke width is read from ``config['character_rendering']['stroke_width']``,
    the same key text_stroke uses. If that key is missing, it falls back to a
    top-level ``config['stroke_width']`` (the key this function used to read),
    then to 1. Like text_stroke, the width is multiplied by ``scale`` and
    rounded up. It is then added on both sides of each dimension.

    Returns:
        (width, height) tuple.
    """
    config = context['config']
    stroke_width = config.get('character_rendering', {}).get(
        'stroke_width', config.get('stroke_width', 1))
    stroke_width = int(math.ceil(stroke_width * scale))
    w, h = font.getsize(char)
    return w + 2*stroke_width, h + 2*stroke_width


def draw_sun_dashes(context, charimg, scale=1):
    """Overwrite thin wedge-shaped "sun ray" slices of ``charimg`` in place.

    Logic ported from a Perl implementation. Triangles fan out from an apex
    placed below the vertical centre of the image. Each triangle is filled and
    outlined with (255, 0), which is transparent white when ``charimg`` has
    mode 'LA' (as render_character creates it). ``context`` is unused.
    """
    ray_length = 30 * scale
    center_x = (charimg.size[0]/2)
    center_y = (charimg.size[1]/2)*1.23
    theta = -0.3
    needs_shift = True
    painter = ImageDraw.Draw(charimg)
    erase = (255, 0)

    while True:
        # Once the sweep passes 3.4 rad, move the apex up a single time and step
        # the angle back by 0.5 rad.
        if needs_shift and theta > 3.4:
            center_y -= 14 * scale
            needs_shift = False
            theta -= 0.5
        if theta > 6.48:
            break
        apex = (center_x, center_y)
        # Gap before the wedge: 0.6..1.0 rad. Wedge width: 0.1..0.15 rad.
        theta += 0.6 + random.random()*0.4
        first = (ray_length*math.cos(theta) + center_x, ray_length*math.sin(theta) + center_y)
        theta += 0.1 + random.random()*0.05
        second = (ray_length*math.cos(theta) + center_x, ray_length*math.sin(theta) + center_y)
        painter.polygon([apex, first, second], fill=erase, outline=erase)


def render_character(context, char):
    """Render one outlined character as an 'LA' image.

    The glyph is drawn with text_stroke, using fill colour (255, 0) for the
    interior and text_stroke's default stroke colour for the outline. It is
    then optionally rotated (``rotate_character``) and cut with
    draw_sun_dashes (``dashes``). When ``smooth_scaling`` != 1, everything is
    drawn at that many times the target size and downsampled with bicubic
    filtering at the end.
    """
    char_config = context['config']['character_rendering']

    scale = context['config']['scale']
    smooth_scaling = char_config.get('smooth_scaling', 1)

    font = get_font(context, char, scale=scale*smooth_scaling)
    # text_stroke paints the outline up to this many pixels around the text
    # origin. Compute it the same way so the canvas is padded on every side.
    stroke = int(math.ceil(char_config.get('stroke_width', 1) * smooth_scaling))
    glyph_w, glyph_h = font.getsize(char)
    charimg = Image.new('LA', (glyph_w + 2*stroke, glyph_h + 2*stroke), color=(255, 0))
    draw = ImageDraw.Draw(charimg)
    # Draw at (stroke, stroke), not (0, 0). Otherwise the left/top part of the
    # outline (negative offsets in text_stroke) falls off the canvas.
    text_stroke(context, draw, (stroke, stroke), char, font=font, fillcolor=(255, 0), scale=smooth_scaling)
    if char_config.get('rotate_character'):
        rotate_config = char_config['rotate_character']
        angle = random.uniform(rotate_config['degrees_min'], rotate_config['degrees_max'])
        charimg = charimg.rotate(angle, resample=Image.BICUBIC, expand=True)
    if char_config.get('dashes'):
        draw_sun_dashes(context, charimg, scale=scale*smooth_scaling)

    if smooth_scaling == 1:
        return charimg
    w, h = charimg.size
    return charimg.resize((int(w/smooth_scaling), int(h/smooth_scaling)), resample=Image.BICUBIC)


def locate_borders(charimg, orientation):
    """Find the first and last column ('x') or row ('y') with a visible pixel.

    A pixel counts as visible when its last band (the alpha channel for
    'LA'/'RGBA' images) is non-zero.

    Returns:
        (first, last) indices, both inclusive.

    Raises:
        ValueError: if ``orientation`` is not 'x' or 'y', or if no pixel is visible.
    """
    if orientation not in ('x', 'y'):
        raise ValueError("orientation must be 'x' or 'y', got %r" % (orientation,))
    width, height = charimg.size

    if orientation == 'x':
        scan_len, cross_len = width, height

        def alpha_at(a, b):
            return charimg.getpixel((a, b))[-1]
    else:
        scan_len, cross_len = height, width

        def alpha_at(a, b):
            return charimg.getpixel((b, a))[-1]

    def first_visible(lines):
        # Return the first line (column or row) in scan order with any visible pixel.
        for a in lines:
            if any(alpha_at(a, b) for b in range(cross_len)):
                return a
        # An explicit exception; an assert would be stripped under -O and
        # silently return None.
        raise ValueError("Image is empty")

    start = first_visible(range(scan_len))
    end = first_visible(range(scan_len - 1, -1, -1))
    return start, end


def crop_image(img):
    """Crop ``img`` to the bounding box found by locate_borders (columns, then rows)."""
    x_first, x_last = locate_borders(img, 'x')
    y_first, y_last = locate_borders(img, 'y')
    # locate_borders returns inclusive indices; the crop box's right/bottom
    # edges are exclusive, hence the +1.
    box = (x_first, y_first, x_last + 1, y_last + 1)
    return img.crop(box)


def split_into_random_buckets(total, numbuckets, bucketmax):
    """Randomly distribute ``total`` unit increments over ``numbuckets`` buckets.

    Each increment goes to a uniformly chosen bucket that has not yet reached
    ``bucketmax``.

    Returns:
        list of ``numbuckets`` ints summing to ``total``, each in [0, bucketmax].

    Raises:
        ValueError: if ``total`` is negative or exceeds ``numbuckets * bucketmax``.
    """
    # TODO: optimize?
    if not 0 <= total <= numbuckets*bucketmax:
        raise ValueError('cannot split %r into %r buckets of at most %r'
                         % (total, numbuckets, bucketmax))
    buckets = [0]*numbuckets
    # A real list, so .remove() works on Python 3 too (range objects are immutable).
    incomplete_buckets = list(range(len(buckets)))
    for _ in range(total):
        index = random.choice(incomplete_buckets)
        buckets[index] += 1
        if buckets[index] == bucketmax:
            incomplete_buckets.remove(index)
    return buckets


def render_initial_text(context, text):
    """Render ``text`` left to right into a white 'L' image of the configured size.

    Each character is rendered and cropped on its own. Neighbouring characters
    overlap by a random amount in [overlap.min, overlap.max] pixels. Each
    character is vertically centred plus a random shift in
    [text_vertical_shift.min, text_vertical_shift.max]. Empty ``text`` yields a
    blank image.

    Raises:
        FailureToFit: if the characters do not fit the width even at maximum overlap.
    """
    config = context['config']
    width, height = config['width'], config['height']
    img = Image.new('L', (width, height), color=255)
    vert_shift_min = config.get('text_vertical_shift', {}).get('min', 0)
    vert_shift_max = config.get('text_vertical_shift', {}).get('max', 0)
    overlap_min = config.get('overlap', {}).get('min', 0)
    overlap_max = config.get('overlap', {}).get('max', 0)

    if not text:
        # Nothing to lay out. The overlap arithmetic below needs at least one
        # character (with none, randint would get an empty range).
        return img

    charimgs = [crop_image(render_character(context, char)) for char in text]
    total_width = sum(charimg.size[0] for charimg in charimgs)

    s = len(text) - 1  # spaces between characters
    if total_width - overlap_max*s > width:
        raise FailureToFit()

    effective_total_min_overlap = max(overlap_min*s, total_width - width)
    total_overlap = random.randint(effective_total_min_overlap, overlap_max*s)
    # Every gap already gets overlap_min below, so each bucket may only receive
    # up to (overlap_max - overlap_min) more. Otherwise a gap could exceed overlap_max.
    extra = split_into_random_buckets(total_overlap - overlap_min*s, s, overlap_max - overlap_min)
    overlaps = [o + overlap_min for o in extra]
    overlaps.append(0)  # nothing follows the last character

    assert sum(overlaps) == total_overlap

    x = 0
    for overlap, charimg in zip(overlaps, charimgs):
        assert 0 <= x < width

        cw, ch = charimg.size
        # Floor division keeps the paste offset an int on Python 3 as well.
        y = (height - ch)//2
        y += random.randint(vert_shift_min, vert_shift_max)

        img.paste(charimg, (x, y), mask=charimg.split()[1])

        x += cw - overlap

    return img


def paste_image(context, cur_config, srcimg):
    """Alpha-composite the overlay resource ``cur_config['image']`` over ``srcimg``.

    The overlay's raw bytes are looked up in ``context['resources']``.

    Returns:
        A new RGBA image. PIL requires the overlay to be the same size as ``srcimg``.
    """
    data = context['resources'].get(cur_config['image'])
    overlay = Image.open(cStringIO.StringIO(data))
    # Image.alpha_composite requires both operands in mode RGBA. Converting the
    # overlay as well lets RGB, palette and grayscale resources work.
    return Image.alpha_composite(srcimg.convert('RGBA'), overlay.convert('RGBA'))


def image_to_matrix(img):
    """Convert an image into a float matrix indexed ``[x, y]``.

    Pixel value 0 (black) maps to 1.0 and 255 (white) to 0.0. This presumably
    expects a single-band 8-bit image; confirm for other modes.
    """
    intensity = np.array(img) / 255.0
    return 1.0 - intensity.T


def matrix_to_image(matrix):
    """Inverse of image_to_matrix: turn a ``[x, y]`` float matrix into an 'L' image.

    1.0 maps to black (0) and 0.0 to white (255).
    """
    pixels = (1.0 - matrix.transpose())*255.0
    # Round instead of letting astype truncate. Truncation turns e.g. 203.9999
    # into 203, so repeated image <-> matrix round trips can drift darker.
    # Clip so values slightly outside [0, 1] cannot wrap around in the uint8 cast.
    pixels = np.clip(np.rint(pixels), 0, 255).astype('uint8')
    return Image.fromarray(pixels, 'L')


def shift_vector(vector, sh):
    """Shift 1-D ``vector`` by ``sh`` positions, filling vacated cells with 0.

    ``result[i] == vector[i - sh]`` whenever ``0 <= i - sh < len(vector)``, and
    0 elsewhere. A positive ``sh`` moves the content towards higher indices; a
    negative ``sh`` moves it towards lower ones. For ``sh == 0`` the input
    itself is returned; otherwise a new float array.
    """
    if sh == 0:
        return vector

    size = vector.shape[0]
    result = np.zeros(size)
    if abs(sh) >= size:
        # Everything is shifted out. The slices below would also mis-align here.
        return result
    if sh > 0:
        result[sh:] = vector[:size - sh]
    else:
        result[:size + sh] = vector[-sh:]
    return result


def wave_distortion(cur_config, matrix):
    """Shift each column ``matrix[x, :]`` along y by a sine of x, with linear interpolation.

    Config keys:
        amplitude_min / amplitude_max: range for the wave amplitude (pixels).
        length_min / length_max: range for the wave length (columns per period).
        random_phase (optional, default False): if true, start the wave at a
            random phase in [0, length); otherwise at phase 0 (the historic behaviour).

    Returns:
        A new float matrix with the same shape as ``matrix``.
    """
    amp = random.uniform(cur_config['amplitude_min'], cur_config['amplitude_max'])
    length = random.uniform(cur_config['length_min'], cur_config['length_max'])
    w = 1.0 / length
    # The phase is always drawn, so the random sequence (and thus seeded
    # output) is the same whether or not it is used.
    start = random.uniform(0, length)
    if not cur_config.get('random_phase', False):
        start = 0

    width, height = matrix.shape[0], matrix.shape[1]

    result = np.zeros((width, height))
    for x in range(width):
        inv_shift = -amp * np.sin(2.0*np.pi*(x+start) * w)
        inv_shift_floor = int(np.floor(inv_shift))
        d = inv_shift - inv_shift_floor  # interpolation coefficient
        assert .0 <= d <= 1.0

        column = matrix[x, :]
        vec1 = shift_vector(column, inv_shift_floor)
        vec2 = shift_vector(column, inv_shift_floor + 1)
        result[x, :] = (1 - d)*vec1 + d*vec2

    return result


def interpolate_points(matrix, x, y):
    """Bilinearly interpolate ``matrix`` (indexed ``[x, y]``) at a fractional position.

    Any of the four neighbouring cells outside the matrix contributes 0.
    """
    width, height = matrix.shape[0], matrix.shape[1]

    xf = int(np.floor(x))
    yf = int(np.floor(y))

    dx = x - xf
    dy = y - yf

    def value_at(i, j):
        # Out-of-range cells read as 0. Without this, negative indices would
        # wrap around to the opposite edge and large ones would raise IndexError.
        if 0 <= i < width and 0 <= j < height:
            return matrix[i, j]
        return .0

    p00 = value_at(xf, yf)
    p10 = value_at(xf + 1, yf)
    p01 = value_at(xf, yf + 1)
    p11 = value_at(xf + 1, yf + 1)

    result = (1-dx)*(1-dy)*p00
    result += dx*(1-dy)*p10
    result += (1-dx)*dy*p01
    result += dx*dy*p11
    return result


def swirl_distortion(cur_config, matrix):
    """Return a swirled copy of ``matrix`` (indexed ``matrix[x, y]``).

    For every output cell (i, j), its polar angle around the matrix centre is
    increased by ``1 / (factor*radius + 4/pi)``. That is pi/4 at the centre and
    shrinks with distance for a positive factor. The source value is then read
    with interpolate_points at the rotated position, clamped into the matrix.
    This is a pure-Python loop over every cell.

    ``factor`` is drawn uniformly from [cur_config['factor_min'],
    cur_config['factor_max']]. NOTE(review): a factor <= 0 can make the
    denominator zero or negative -- presumably configs keep it positive; confirm.
    """
    factor = random.uniform(cur_config['factor_min'], cur_config['factor_max'])

    width, height = matrix.shape[0], matrix.shape[1]
    cx = width/2.0
    cy = height/2.0

    result = np.zeros((width, height))

    for i in xrange(width):
        rel_x = i - cx
        for j in xrange(height):
            # y axis flipped: rel_y increases as j decreases.
            rel_y = cy - j
            # Polar angle of (rel_x, rel_y) in [0, 2*pi), worked out per quadrant
            # (same as arctan2 mapped to [0, 2*pi), up to rounding).
            if rel_x:
                orig_angle = np.arctan(float(abs(rel_y))/abs(rel_x))
                if rel_x > 0 and rel_y < 0:
                    orig_angle = 2.0*np.pi - orig_angle
                elif rel_x <= 0 and rel_y >= 0:
                    orig_angle = np.pi - orig_angle
                elif rel_x <= 0 and rel_y < 0:
                    orig_angle += np.pi
            elif rel_y >= 0:
                orig_angle = 0.5 * np.pi
            else:
                orig_angle = 1.5 * np.pi

            radius = np.sqrt(rel_x**2 + rel_y**2)
            new_angle = orig_angle + 1/(factor*radius + (4.0/np.pi))

            # NOTE(review): the +0.5 is added before the y flip below, so with zero
            # rotation the sample point would be (i + 0.5, j - 0.5). Confirm that
            # this half-pixel asymmetry is intended.
            src_x = radius * np.cos(new_angle) + 0.5
            src_y = radius * np.sin(new_angle) + 0.5

            # Back from centred, y-up coordinates to matrix indices.
            src_x += cx
            src_y = height - src_y - cy

            # Clamp to the matrix. Values in [size-1, size) are left as is;
            # interpolate_points handles the missing neighbour.
            if src_x < 0:
                src_x = 0
            elif src_x >= width:
                src_x = width - 1

            if src_y < 0:
                src_y = 0
            elif src_y >= height:
                src_y = height - 1

            result[i, j] = interpolate_points(matrix, src_x, src_y)

    return result


def bezier_curve(points, steps=100):
    """Sample ``steps`` evenly spaced points (t = 0..1) on the Bezier curve of ``points``.

    Args:
        points: control points, all tuples of the same dimension.
        steps: number of samples. With 1, only the start point (t = 0) is
            returned; with fewer than 1, the result is empty.

    Returns:
        list of tuples of floats. Empty if ``points`` is empty.
    """
    # Code adapted from https://stackoverflow.com/a/2292690
    def pascal_row(n):
        # Row n of Pascal's triangle, i.e. the binomial coefficients C(n, k).
        result = [1]
        x, numerator = 1, n
        for denominator in range(1, n//2 + 1):
            x *= numerator
            x //= denominator  # exact: the product is k * C(n, k)
            result.append(x)
            numerator -= 1

        if n & 1 == 0:
            # n is even: do not mirror the middle element
            result.extend(reversed(result[:-1]))
        else:
            result.extend(reversed(result))
        return result

    if not points or steps < 1:
        return []

    n = len(points)
    dim = len(points[0])
    combinations = pascal_row(n - 1)
    # A single step would otherwise divide by zero computing t.
    if steps == 1:
        ts = [0.0]
    else:
        ts = [i/float(steps - 1) for i in range(steps)]

    result = []
    for t in ts:
        tpowers = [t**i for i in range(n)]
        upowers = [(1 - t)**i for i in range(n)][::-1]
        coefs = [c*a*b for c, a, b in zip(combinations, tpowers, upowers)]
        result.append(tuple(sum(coef*point[i] for coef, point in zip(coefs, points))
                            for i in range(dim)))

    return result


def select_bezier_control_points(size, cur_config):
    """Pick a random number of random control points inside ``size`` = (width, height).

    The count comes from [cur_config['parts']['min'], cur_config['parts']['max']].
    Coordinates are integers in [0, width] and [0, height], and the points are
    ordered by x.
    """
    width, height = size
    bounds = cur_config['parts']
    count = random.randint(bounds['min'], bounds['max'])
    # All x values are drawn before any y value; keep this order for seeded runs.
    xs = [random.randint(0, width) for _ in range(count)]
    xs.sort()
    ys = [random.randint(0, height) for _ in range(count)]
    return list(zip(xs, ys))


def draw_bezier_line(cur_config, img):
    """Draw a random Bezier curve onto ``img`` in place and return ``img``.

    PIL.ImageDraw does not antialias lines. As a workaround, the curve is
    drawn on an 'LA' canvas ``smooth_scaling`` (default 2) times larger. That
    canvas is downsampled with bilinear filtering and pasted using its alpha
    channel as the mask.
    """
    width, height = img.size
    factor = cur_config.get('smooth_scaling', 2)

    control = select_bezier_control_points((width, height), cur_config)
    upscaled = [(px*factor, py*factor) for px, py in control]
    curve = bezier_curve(upscaled, steps=cur_config.get('steps', 100))

    overlay = Image.new('LA', (width*factor, height*factor), color=(255, 0))
    pen_width = cur_config.get('width', 1)*factor
    ImageDraw.Draw(overlay).line(curve, fill='black', width=pen_width)
    overlay = overlay.resize((width, height), resample=Image.BILINEAR)

    alpha = overlay.split()[1]
    img.paste(overlay, mask=alpha)

    return img


def add_white_noise(context, cur_config, img):
    """Return the result of whitenoise.add_noise_to_single_img on ``img``.

    ``cur_config['scale']`` is passed through, defaulting to
    whitenoise.DEFAULT_SCALE. ``context`` is unused.
    """
    return whitenoise.add_noise_to_single_img(
        img, scale=cur_config.get('scale', whitenoise.DEFAULT_SCALE))


def render_attempt(context, text):
    """Render ``text`` once and apply ``config['image_transformations']`` in order.

    The working copy is kept either as a PIL image or as a float matrix (see
    image_to_matrix). It is converted only when the next transformation needs
    the other representation.

    Raises:
        FailureToFit: propagated from render_initial_text.
        RuntimeError: if a transformation has an unknown ``type``.
    """
    # A mutable dict, so the nested helpers can rebind both entries
    # (Python 2 has no ``nonlocal``).
    data = {
        'img': render_initial_text(context, text),
        'matrix': None
    }

    def convert_to_matrix():
        if data['matrix'] is not None:
            return
        data['matrix'] = image_to_matrix(data['img'])
        data['img'] = None

    def convert_to_image():
        # Compare against None explicitly, as convert_to_matrix does, instead of
        # relying on the truthiness of the image object.
        if data['img'] is not None:
            return
        data['img'] = matrix_to_image(data['matrix'])
        data['matrix'] = None

    for transformation in context['config'].get('image_transformations', []):
        ttype = transformation['type']
        if ttype == 'swirl_distortion':
            convert_to_matrix()
            data['matrix'] = swirl_distortion(transformation, data['matrix'])
        elif ttype == 'wave_distortion':
            convert_to_matrix()
            data['matrix'] = wave_distortion(transformation, data['matrix'])
        elif ttype == 'draw_bezier_line':
            convert_to_image()
            data['img'] = draw_bezier_line(transformation, data['img'])
        elif ttype == 'paste_image':
            convert_to_image()
            data['img'] = paste_image(context, transformation, data['img'])
        elif ttype == 'add_white_noise':
            convert_to_image()
            data['img'] = add_white_noise(context, transformation, data['img'])
        else:
            raise RuntimeError('Unknown transformation: %s' % repr(ttype))

    convert_to_image()
    return data['img']


def render(context, text):
    """Render ``text``, retrying when the randomly sized characters do not fit.

    Up to ``config['max_attempts']`` (default 10) attempts are made. Each
    FailureToFit is logged at INFO level.

    Raises:
        RuntimeError: if every attempt raised FailureToFit.
    """
    failures = 0
    while True:
        if failures >= context['config'].get('max_attempts', 10):
            raise RuntimeError("Couldn't fit text into image, giving up")
        try:
            return render_attempt(context, text)
        except FailureToFit:
            logging.info("Failed to fit text %s into image", repr(text))
        failures += 1
