#!/usr/bin/env python
# -*- coding: utf-8 -*-

from StringIO import StringIO
import json
import random
import argparse
import yt.wrapper as yt
import sys
from datetime import datetime
from captcha_tools import TarballContainer, NullContainer, generic_read_images_table, generic_store_results

def parse_args():
    """Parse the command line of the captcha archive preparation tool.

    Options are declared as a table of ``(flag, add_argument keywords)``
    pairs and registered in that order, so the generated ``--help`` output
    lists them in the same sequence as the table.

    Returns the ``argparse.Namespace`` produced from ``sys.argv``.
    """
    option_table = (
        ('--yt-proxy', dict(default='hahn', help='YT proxy')),

        ('--max-size', dict(type=int, default=400,
                            help='Resample all images with size above this')),
        ('--min-size', dict(type=int, default=100,
                            help='Drop all images with size below this')),

        ('--max-images', dict(type=int, default=10000, help='Max images')),
        ('--images-table', dict(required=True, help='Path to the images table')),
        ('--has-answer', dict(default=False, action='store_true',
                              help='Whether images are known')),

        ('--debug-tar', dict(help='Gzipped tarball with images for debugging')),
        ('--result-tar', dict(help='Gzipped tarball with resulting images')),

        ('--progress-report', dict(
            type=int, default=0,
            help='Report after processing each specified amount of images (0 to disable)')),

        ('--image-format', dict(choices=['gif', 'png', 'jpeg'], default='gif',
                                help='Image format')),
    )

    parser = argparse.ArgumentParser(description='Prepare captcha images archive')
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    return parser.parse_args()

def make_image_crop(record, mismatch_tolerance=1):
    """Cut a square crop out of the image stored in ``record``.

    ``record["image"]`` is opened with PIL via ``StringIO``.  The crop box is
    given by ``record["crop_xmin"/"crop_xmax"/"crop_ymin"/"crop_ymax"]``,
    which are multiplied by the image width/height and rounded to integer
    pixel coordinates.

    If the rounded box is not square, the longer side is shortened (by
    lowering its max coordinate) provided the difference does not exceed
    ``mismatch_tolerance`` pixels.

    Raises RuntimeError when the width/height difference is larger than
    ``mismatch_tolerance``; AssertionError when the final box is empty or
    does not lie inside the image.

    Returns the cropped PIL image.
    """
    from PIL import Image

    picture = Image.open(StringIO(record["image"]))
    width, height = picture.size

    # Scale the crop box by the image dimensions first, then round every
    # edge to a whole pixel.
    scaled_edges = (
        width*record["crop_xmin"],
        width*record["crop_xmax"],
        height*record["crop_ymin"],
        height*record["crop_ymax"],
    )
    left, right, top, bottom = [int(round(edge)) for edge in scaled_edges]

    box_width = right - left
    box_height = bottom - top
    excess = abs(box_height - box_width)
    if excess:
        if excess > mismatch_tolerance:
            raise RuntimeError("Size mismatch for record with id %s"%repr(record['id']))
        # Rounding made one side slightly longer: trim it to get a square.
        if box_height > box_width:
            bottom -= excess
        else:
            right -= excess

    assert right - left == bottom - top
    assert 0 <= left < right <= width
    assert 0 <= top < bottom <= height
    return picture.crop((left, top, right, bottom))

def to_crop_coordinates(crop_min, crop_max, coord):
    """Express ``coord`` relative to the interval ``[crop_min, crop_max]``.

    The result is ``(coord - crop_min) / (crop_max - crop_min)`` clamped to
    the range 0.0 .. 1.0, i.e. 0.0 at (or before) ``crop_min`` and 1.0 at
    (or beyond) ``crop_max``.

    crop_min, crop_max -- bounds of the crop along one axis.
    coord              -- coordinate along the same axis, in the same units.

    Always returns a float.  Raises ZeroDivisionError when
    ``crop_max == crop_min``.
    """
    # float() forces true division: on Python 2 the plain `/` operator floors
    # the quotient when all three arguments happen to be integers.
    crop_coord = float(coord - crop_min) / (crop_max - crop_min)
    if crop_coord < 0.0:
        return 0.0
    if crop_coord > 1.0:
        return 1.0
    return crop_coord

def load_images(src, max_images, min_size, max_size, has_answer):
    """Generate prepared image entries from the images table ``src``.

    Records come from ``generic_read_images_table(src, None)``.  Each record
    is cropped with ``make_image_crop``; crops narrower than ``min_size`` are
    skipped and crops wider than ``max_size`` are resized to
    ``max_size`` x ``max_size`` with bicubic resampling.

    Every yielded dict has the keys ``'image'`` (PIL image), ``'id'`` and
    ``'category_nominative'``.  When ``has_answer`` is true it also has
    ``'answer'``: ``[[xmin, ymin], [xmax, ymax]]`` built from the record's
    ``true_*`` fields converted with ``to_crop_coordinates``.

    Generation stops once the number of yielded entries equals
    ``max_images``.
    """
    from PIL import Image

    emitted = 0
    for record in generic_read_images_table(src, None):
        cropped = make_image_crop(record)
        entry = {
            'image': cropped,
            'id': record['id'],
            'category_nominative': record['category'],
        }

        # Only the width is inspected for the size limits.
        side = cropped.size[0]
        if side < min_size:
            continue
        if side > max_size:
            entry['image'] = cropped.resize((max_size, max_size), Image.BICUBIC)

        if has_answer:
            x_low, x_high = record['crop_xmin'], record['crop_xmax']
            answer_xmin = to_crop_coordinates(x_low, x_high, record['true_xmin'])
            answer_xmax = to_crop_coordinates(x_low, x_high, record['true_xmax'])
            y_low, y_high = record['crop_ymin'], record['crop_ymax']
            answer_ymin = to_crop_coordinates(y_low, y_high, record['true_ymin'])
            answer_ymax = to_crop_coordinates(y_low, y_high, record['true_ymax'])
            entry['answer'] = [[answer_xmin, answer_ymin], [answer_xmax, answer_ymax]]

        yield entry

        emitted += 1
        if emitted == max_images:
            return


def store_results(image_iter, result_container, debug_container, image_format):
    """Hand the prepared images over to ``generic_store_results``.

    Each entry of ``image_iter`` is split into its ``'image'`` value and a
    JSON string holding the remaining keys plus a ``'content_type'`` key
    derived from ``image_format``.

    Returns whatever ``generic_store_results`` returns.
    """
    def split_image_and_metadata(image_data):
        # Work on a copy so the entry produced by the iterator is not mutated.
        metadata = dict(image_data)
        picture = metadata.pop('image')
        metadata['content_type'] = 'image/%s; charset=utf-8' % image_format
        return picture, json.dumps(metadata)

    return generic_store_results(
        image_iter,
        result_container,
        debug_container,
        split_image_and_metadata,
        image_format=image_format,
    )

def add_progress_report(interval, iterator):
    """Yield the items of ``iterator`` unchanged, reporting progress to stderr.

    Each time the running item count becomes a multiple of ``interval`` a
    "Loaded N images" line is written to ``sys.stderr``; the line is written
    before the corresponding item is yielded.

    interval -- number of items between two reports.  0 disables reporting
                (previously it raised ZeroDivisionError on the first item).
    iterator -- any iterable.
    """
    for num, item in enumerate(iterator, 1):
        # `interval and ...` keeps `num % 0` from ever being evaluated.
        if interval and num % interval == 0:
            # Same output as the Python 2 `print >>sys.stderr` statement, but
            # written in a form that does not depend on print-statement syntax.
            sys.stderr.write("Loaded %d images\n" % num)
        yield item

def main():
    """Script entry point: read images, prepare them and store the archives.

    Parses the command line, configures the YT client, builds the (lazy)
    image stream, optionally wraps it with progress reporting, and passes it
    to ``store_results`` together with the result and debug containers.
    """
    options = parse_args()

    yt.config["proxy"]["url"] = options.yt_proxy
    yt.config["auto_merge_output"]["action"] = "merge"

    image_stream = load_images(
        options.images_table,
        options.max_images,
        options.min_size,
        options.max_size,
        options.has_answer,
    )
    if options.progress_report:
        image_stream = add_progress_report(options.progress_report, image_stream)

    # The result container is created before the debug one, as in the
    # original statement order.
    result_sink = TarballContainer(options.result_tar)
    debug_sink = TarballContainer(options.debug_tar) if options.debug_tar else NullContainer()

    store_results(image_stream, result_sink, debug_sink, options.image_format)

# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
