import argparse
import random
import logging
import yt.wrapper as yt


def is_image_acceptable(record, params):
    """Tell whether an image record passes the configured filters.

    Args:
        record: mapping describing one image. ``record["Type"]`` is read only
            when ``only_hard_pictures`` is enabled; ``record["WordGT"]`` (the
            answer, as UTF-8 bytes or already-decoded text) is optional.
        params: mapping of filter settings; every key is optional and a
            missing or falsy value disables that filter:
            ``only_hard_pictures`` - accept only records whose type is "DIFF";
            ``only_single_word``   - reject answers containing a space;
            ``word_min_length``    - minimal answer length in characters;
            ``word_max_length``    - maximal answer length in characters.

    Returns:
        True if the record is acceptable, False otherwise. Records without a
        ``WordGT`` field skip all answer-based filters.

    Raises:
        KeyError: if ``only_hard_pictures`` is enabled and the record has no
            ``Type`` field.
    """
    if params.get('only_hard_pictures') and record["Type"] != "DIFF":
        return False

    if 'WordGT' not in record:
        return True

    answer = record['WordGT']
    # The answer may arrive as raw UTF-8 bytes or as text that has already been
    # decoded; decode only bytes so that the length limits count characters
    # and text input does not fail on a missing ``decode`` method.
    if isinstance(answer, bytes):
        answer = answer.decode('utf-8')

    if params.get('only_single_word') and len(answer.split(' ')) > 1:
        return False

    min_length = params.get('word_min_length')
    if min_length and len(answer) < min_length:
        return False

    max_length = params.get('word_max_length')
    if max_length and len(answer) > max_length:
        return False

    return True


def tail_by_fields(src, dst, fields, count):
    """Sort ``src`` by ``fields`` into ``dst`` and trim ``dst`` to its last ``count`` rows.

    Both steps run inside a single YT transaction. When the sorted table has
    ``count`` rows or fewer it is left as is.

    Args:
        src: path of the table to sort.
        dst: path of the resulting table (may be the same as ``src``).
        fields: list of column names to sort by.
        count: maximum number of trailing rows to keep.
    """
    with yt.Transaction():
        yt.run_sort(src, dst, sort_by=fields)
        total_rows = yt.get(dst + '/@row_count')
        surplus = total_rows - count
        if surplus > 0:
            # Rewrite dst with only the rows starting at index ``surplus``.
            tail = yt.TablePath(dst, start_index=surplus)
            yt.run_merge(tail, dst)


def prepare_latest_input_images(src, dst, answer_column, filter_params):
    """Filter the images of ``src`` and write them to ``dst`` with normalized column names.

    Every row accepted by ``is_image_acceptable(row, filter_params)`` is
    converted to ``unique_name`` / ``upload_date`` / ``image`` columns (taken
    from ``Unique_name`` / ``Upload_date`` / ``Image``).

    Args:
        src: path of the input table.
        dst: path of the output table.
        answer_column: name of the input column copied to ``answer``;
            ``None`` means no ``answer`` column is produced.
        filter_params: filter settings forwarded to ``is_image_acceptable``.
    """
    def mapper(rec):
        if is_image_acceptable(rec, filter_params):
            row = dict(
                unique_name=rec['Unique_name'],
                upload_date=rec['Upload_date'],
                image=rec['Image'],
            )
            if answer_column is not None:
                row['answer'] = rec[answer_column]
            yield row

    yt.run_map(mapper, src, dst)


def remove_by_blacklisted_fields(src, blacklist, dst, fields):
    """Write to ``dst`` the rows of ``src`` whose key does not occur in ``blacklist``.

    Rows of both tables are grouped by ``fields``; a group is emitted only if
    every row in it originates from ``src`` (input table index 0).

    Args:
        src: path of the table to filter.
        blacklist: path of the table whose keys must be excluded.
        dst: path of the output table.
        fields: list of column names forming the key.
    """
    SOURCE_TAG = '__table_index'

    def mapper(rec):
        # Move the input table index from the control attribute into an
        # ordinary column so that it is still available in the reducer.
        rec[SOURCE_TAG] = rec.pop('@table_index')
        yield rec

    def reducer(key, recs):
        kept = []
        for rec in recs:
            if rec[SOURCE_TAG] != 0:
                # A blacklist row shares this key: drop the whole group.
                return
            del rec[SOURCE_TAG]
            kept.append(rec)
        for rec in kept:
            yield rec

    row_format = yt.YsonFormat(control_attributes_mode='row_fields')
    yt.run_map_reduce(mapper, reducer, [src, blacklist], dst, reduce_by=fields,
                      format=row_format)


def extract_unprocessed_images(previous_batch, processed_lists, dst):
    """Collect images of the previous batch that are absent from the processed lists.

    Rows of ``previous_batch`` are recognized by their ``unknown_id`` column
    and rows of the processed lists by their ``unique_name`` column; a row
    must have exactly one of the two. The result has one
    ``unique_name`` / ``image`` row per image that occurs in the previous
    batch and in none of the processed lists.

    Args:
        previous_batch: path of the previous batch table.
        processed_lists: paths of tables listing already processed images;
            ``None`` or an empty sequence means nothing has been processed.
        dst: path of the output table.
    """
    def mapper(rec):
        from_processed_list = 'unique_name' in rec
        from_previous_batch = 'unknown_id' in rec
        assert from_processed_list or from_previous_batch
        assert not (from_processed_list and from_previous_batch)
        if from_processed_list:
            yield {
                'unique_name': rec['unique_name']
            }
        else:
            yield {
                'unique_name': rec['unknown_id'],
                'image': rec['unknown_image']
            }

    def reducer(key, recs):
        image = None
        for rec in recs:
            if 'image' not in rec:
                # A row without an image comes from a processed list,
                # so this image has already been handled: emit nothing.
                return
            image = rec['image']

        if image is not None:
            yield {
                'unique_name': key['unique_name'],
                'image': image
            }

    with yt.Transaction():
        # The list of processed tables is optional on the command line and is
        # None when the option is omitted; treat that as "no processed tables".
        src = [previous_batch] + list(processed_lists or [])
        yt.run_map_reduce(mapper, reducer, src, dst, reduce_by=['unique_name'], mapper_memory_limit=1024**3)


class MapBuildBase(object):
    """Mapper that pairs every unknown image with randomly chosen known images.

    For each input row (an unknown image with ``unique_name`` and ``image``)
    it emits ``known_per_unknown`` rows, each combining the unknown image with
    one known image picked at random from the local ``known_images`` file.
    """

    def __init__(self, known_per_unknown):
        """Remember how many known images to pair with every unknown one."""
        self.known_per_unknown = known_per_unknown
        # Running counter stored in the 'index' column of the emitted rows.
        self.next_index = 0

    def start(self):
        """Load the known images from the local YSON file ``known_images``.

        NOTE(review): presumably invoked by the YT job runtime before the
        first row is processed - confirm against the wrapper documentation.
        """
        # YSON is a binary format: read the file in binary mode and close it
        # as soon as all rows have been materialized.
        with open('known_images', 'rb') as known_file:
            self.known_images = list(yt.YsonFormat().load_rows(known_file))

    def __call__(self, rec):
        """Yield ``known_per_unknown`` pair rows for the unknown image ``rec``."""
        for i in range(self.known_per_unknown):
            index = self.next_index
            self.next_index += 1
            known_image_data = random.choice(self.known_images)

            # Alternate the position of the unknown image between pairs.
            if i % 2 == 0:
                layout = 'unknown,known'
            else:
                layout = 'known,unknown'

            yield {
                'index': index,
                'known_id': known_image_data['unique_name'],
                'unknown_id': rec['unique_name'],
                'known_image': known_image_data['image'],
                'unknown_image': rec['image'],
                'layout': layout,
                'known_answer': known_image_data['answer']
            }


def build_batch(unknown_src, known_src, dst, known_per_unknown):
    """Build the batch table ``dst`` by pairing unknown images with known ones.

    Steps, all inside one YT transaction:
      1. copy the unknown images to ``dst`` adding a random ``random_tag``;
      2. sort ``dst`` by ``random_tag`` (i.e. shuffle the rows);
      3. run ``MapBuildBase`` over ``dst`` in a single job, with the known
         images table attached to the job as the YSON file ``known_images``,
         and write the result back to ``dst`` marked as sorted by ``index``.

    Args:
        unknown_src: table path(s) with unknown images.
        known_src: path of the table with known images.
        dst: path of the resulting batch table.
        known_per_unknown: number of known images paired with each unknown one.
    """
    def map_add_random_tag(rec):
        rec['random_tag'] = random.random()
        yield rec

    with yt.Transaction():
        yt.run_map(map_add_random_tag, unknown_src, dst)
        yt.run_sort(dst, sort_by=['random_tag'])

        # Describe the known images table as a job file in YSON format.
        file_attributes = {
            "format": yt.yson.to_yson_type("yson"),
            "file_name": 'known_images',
        }
        known_file = yt.yson.to_yson_type(known_src, attributes=file_attributes)

        pair_mapper = MapBuildBase(known_per_unknown)
        batch_table = yt.TablePath(dst, sorted_by=['index'])
        yt.run_map(pair_mapper, dst, batch_table,
                   spec={'job_count': 1}, yt_files=[known_file])


def parse_args():
    """Parse the command-line options of the script.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = (
        ('--yt-proxy', dict(default='hahn', help='YT proxy')),
        ('--word-min-length', dict(type=int, default=5, help='Min acceptable word length for known images')),
        ('--word-max-length', dict(type=int, default=10, help='Max acceptable word length for known images')),
        ('--max-unknown-images', dict(type=int, default=25000, help='Max images without answers')),
        ('--known-per-unknown', dict(type=int, default=16, help='How many known images are to be matched with every unknown')),
        ('--known-images-table', dict(required=True, help='Path to the known images table')),
        ('--unknown-images-table', dict(required=True, help='Path to the unknown images table')),
        ('--previous-batch-table', dict(help='Path to the table with previous batch')),
        ('--processed-images-table', dict(nargs='*', help='Path to the table containing list of already processed images')),
        ('--mandatory-unknown-images-table', dict(help='Path to the table containing list of mandatory images')),
        ('--known-answer-column', dict(default='WordGT', help='Name of answer column in known images table (default: "WordGT")')),
        ('--destination-table', dict(required=True, help='Path to the destination')),
    )

    parser = argparse.ArgumentParser(description='Prepare captcha images archive')
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def main():
    """Assemble a new captcha batch table from known and unknown images.

    Unknown images are collected, up to ``--max-unknown-images`` in total,
    from three sources in order of priority:
      1. the mandatory unknown images table (if given and existing);
      2. images of the previous batch that are not in the processed lists;
      3. the rows of the unknown images table with the greatest ``upload_date``.
    Known images are filtered by type and answer length and limited to the
    ``max_unknown_images * known_per_unknown`` rows with the greatest
    ``upload_date``. Finally every unknown image is paired with known images
    and the result is written to the destination table.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="[%(filename)s:%(lineno)d] %(levelname)-8s [%(asctime)s]  %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    args = parse_args()

    yt.config['proxy']['url'] = args.yt_proxy
    yt.config['auto_merge_output']['action'] = 'merge'

    filter_params = {
        'only_hard_pictures': True,
        'only_single_word': True,
        'word_min_length': args.word_min_length,
        'word_max_length': args.word_max_length,
    }

    with yt.Transaction():
        with yt.TempTable() as previous_unprocessed, \
                yt.TempTable() as last_unknown, \
                yt.TempTable() as last_known, \
                yt.TempTable() as mandatory_unknown, \
                yt.TempTable() as mandatory_unknown_map:
            unknown_src = []
            required_unknown_src_total_size = args.max_unknown_images
            if args.mandatory_unknown_images_table and yt.exists(args.mandatory_unknown_images_table):
                logging.info("Process mandatory unknown images")
                mandatory_unknown_images_count = yt.get(args.mandatory_unknown_images_table + '/@row_count')
                if mandatory_unknown_images_count > args.max_unknown_images:
                    # Too many mandatory images: keep only the first max_unknown_images rows.
                    yt.run_merge(yt.TablePath(args.mandatory_unknown_images_table, end_index=args.max_unknown_images), mandatory_unknown)
                    required_unknown_src_total_size = 0
                else:
                    mandatory_unknown = args.mandatory_unknown_images_table
                    required_unknown_src_total_size -= mandatory_unknown_images_count
                # Only the converted table goes into unknown_src: the raw table
                # still has the source column names and would also duplicate
                # every mandatory image.
                prepare_latest_input_images(mandatory_unknown, mandatory_unknown_map, None, {})
                unknown_src.append(mandatory_unknown_map)
                logging.info("Number of unknown images required: %s", required_unknown_src_total_size)

            if args.previous_batch_table and required_unknown_src_total_size > 0:
                logging.info("Process previous batch unknown images")
                # Take from the previous batch the images that are absent from the processed lists.
                # The processed lists option is optional and is None when omitted.
                extract_unprocessed_images(args.previous_batch_table, args.processed_images_table or [], previous_unprocessed)
                previous_unprocessed_images_count = yt.get(previous_unprocessed + '/@row_count')
                if previous_unprocessed_images_count > required_unknown_src_total_size:
                    yt.run_merge(yt.TablePath(previous_unprocessed, end_index=required_unknown_src_total_size), previous_unprocessed)
                    required_unknown_src_total_size = 0
                else:
                    required_unknown_src_total_size -= previous_unprocessed_images_count
                unknown_src.append(previous_unprocessed)
                logging.info("Number of unknown images required: %s", required_unknown_src_total_size)

            if required_unknown_src_total_size > 0:
                # Still not enough: try to take the rest from the unknown images table.
                prepare_latest_input_images(args.unknown_images_table, last_unknown, None, {})
                if args.previous_batch_table:
                    # Remove from last_unknown everything that is already in previous_unprocessed.
                    remove_by_blacklisted_fields(last_unknown, previous_unprocessed, last_unknown, ['unique_name'])
                # Keep only as many rows as are still missing (the last ones by upload_date).
                tail_by_fields(last_unknown, last_unknown, ['upload_date'], required_unknown_src_total_size)
                unknown_src.append(last_unknown)
                required_unknown_src_total_size -= yt.get(last_unknown + '/@row_count')
                logging.info("Number of unknown images required: %s", required_unknown_src_total_size)

            if required_unknown_src_total_size > 0:
                logging.warning("Number of unknown images required is greater than 0: %s", required_unknown_src_total_size)

            known_count = args.max_unknown_images * args.known_per_unknown
            prepare_latest_input_images(args.known_images_table, last_known, args.known_answer_column, filter_params)
            tail_by_fields(last_known, last_known, ['upload_date'], known_count)

            build_batch(unknown_src, last_known, args.destination_table, args.known_per_unknown)


# Run the batch preparation only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
