import argparse
import attr
import itertools
import os
import shutil
import time
import uuid
from collections import Counter

import yt.wrapper as yt

from SerpParser import SerpElementExampleSet, SerpElement, Serp, SerpParser
from SerpSet2Json import SerpSet2Json
from SerpSetSimpleHtm import SerpSetSimpleHtm
from Screenshotter import Image, Screenshotter
from parse_serp_mr_jobs import SerpParserCombiner, SerpElementExampleSetReducer, SerpSetFilterMapper
from settings import SerpSettings, GoogleSerpSettingsTouch, GoogleSerpSettingsTouchShort

__author__ = 'irlab'


# Pipeline stages in execution order. process_mr() compares positions in this
# tuple to decide which stages to run when resuming, and the --continue_from
# command-line option accepts exactly these values.
STAGES = (
    'parse_html',
    'make_examples',
    'make_sample',
)


# Set from the --debug command-line flag in the __main__ section. When true,
# parse_html() uses a smaller data_size_per_job and disables auto_merge.
DEBUG = False


def process_mr(src_tables, out_mr_table_prefix, out_dir, continue_from):
    """Run the pipeline stages on YT tables, optionally resuming at a stage.

    :param src_tables: input tables passed to the parsing map operation
    :param out_mr_table_prefix: prefix every output table name is derived from
    :param out_dir: local directory handed to make_output()
    :param continue_from: a name from STAGES to resume at, or a falsy value to
        run everything; in the latter case the output tables that already
        exist are removed first
    """
    yt.config['spec_defaults']['job_io'] = {'table_writer': {'max_row_weight': 128 * yt.common.MB}}

    # serp_settings = GoogleSerpSettingsTouch()
    serp_settings = GoogleSerpSettingsTouchShort()
    serp_short_settings = GoogleSerpSettingsTouchShort()
    assert isinstance(serp_settings, SerpSettings)

    serp_parser_combiner = SerpParserCombiner(
        serp_settings,
        serp_short_settings,
        raise_exceptions=False,
        verbose=False
    )

    def table_name(suffix):
        return out_mr_table_prefix + suffix

    serponly_table_name = table_name("_serponly")
    examples_table_name = table_name("_examples")
    serponly_filtered_table_name = table_name("_serponly_filt")

    # The order matters: this list is passed as the destination tables of the
    # parsing map operation.
    dst_tables = [
        table_name("_parsed"),
        table_name("_skel"),
        serponly_table_name,
        table_name("_raw_skel"),
        table_name("_errors"),
    ]

    if not continue_from:
        # A fresh run: create the first table recursively, then drop every
        # output table that is present.
        yt.create('table', dst_tables[0], recursive=True, ignore_existing=True)
        stale_candidates = dst_tables + [examples_table_name, serponly_filtered_table_name]
        for candidate in stale_candidates:
            if yt.exists(candidate):
                yt.remove(candidate)

    if continue_from:
        first_stage_index = STAGES.index(continue_from)
    else:
        first_stage_index = 0

    def should_run(stage):
        # A stage runs when it is at or after the stage we resume from.
        return first_stage_index <= STAGES.index(stage)

    if should_run('parse_html'):
        parse_html(dst_tables, serp_parser_combiner, src_tables)

    if should_run('make_examples'):
        make_examples(examples_table_name, out_mr_table_prefix, serp_settings)

    if should_run('make_sample'):
        example_sets = make_sample(examples_table_name, serponly_filtered_table_name, serponly_table_name)
        make_output(example_sets, out_dir, serp_settings, serponly_filtered_table_name)


def parse_html(dstTables, serp_parser_combiner, src_tables):
    print time.ctime(), 'parse_html'
    files = [
        "make_screen.js",
        "phantomjs_bundle.tgz"
    ]
    for file in files:
        assert os.path.isfile(file), "file not exists " + file

    print time.ctime(), 'run_map serp_parser_combiner'
    spec = {
        "title": "serp_anatomy serp_parser_combiner",
        "data_size_per_job": 32 * yt.common.MB,
        "mapper": {
            "memory_limit": 8 * yt.common.GB,
            "tmpfs_size": 1 * yt.common.GB,
            "tmpfs_path": ".",
            "copy_files": True
        },
    }
    if DEBUG:
        spec["data_size_per_job"] = 2 * yt.common.MB
        spec.update({
            "data_size_per_job": 2 * yt.common.MB,
            "auto_merge": {
                "mode": "disabled"
            },
        })
    yt.run_map(
        serp_parser_combiner,
        src_tables,
        dstTables,
        local_files=files,
        spec=spec
    )


def make_examples(examples_table_name, out_mr_table_prefix, serp_settings):
    print time.ctime(), 'make_examples'
    yt.run_sort(
        out_mr_table_prefix + "_skel",
        out_mr_table_prefix + "_skel",
        sort_by=['skeleton_md5', 'seanid'],
        spec=dict(title="serp_anatomy make_examples sort")
    )
    print time.ctime(), 'run_reduce SerpElementExampleSetReducer'
    yt.run_reduce(
        SerpElementExampleSetReducer(serp_settings),
        out_mr_table_prefix + "_skel",
        examples_table_name,
        reduce_by=['skeleton_md5'],
        sort_by=['skeleton_md5', 'seanid'],
        spec=dict(title="serp_anatomy make_examples")
    )


def make_sample(examples_table_name, serponly_filtered_table_name, serponly_table_name):
    print time.ctime(), 'make_sample'
    example_sets = []
    serp_id2example_count = Counter()
    """:type : list[SerpElementExampleSet]"""
    print time.ctime(), 'read_table', examples_table_name
    for rec in yt.read_table(examples_table_name):
        serp_element_example_set = SerpElementExampleSet(**rec)
        example_sets.append(serp_element_example_set)
        for i, serp_element in enumerate(serp_element_example_set.serp_elem_list):
            serp_element = SerpElement(**serp_element)
            if serp_element.image:
                serp_element.image = Image(**serp_element.image)
            serp_element_example_set.serp_elem_list[i] = serp_element
            serp_id2example_count[serp_element.serp_id] += 1
    serp_id_set = set(serp_id2example_count.keys())
    print time.ctime(), 'run_map SerpSetFilterMapper'
    yt.run_map(
        SerpSetFilterMapper(serp_id_set),
        serponly_table_name,
        serponly_filtered_table_name,
        spec = dict(title="serp_anatomy make_sample")
    )
    return example_sets


def make_output(example_sets, out_dir, serp_settings, serponly_filtered_table_name):
    print time.ctime(), 'make_output'
    print time.ctime(), 'read_table', serponly_filtered_table_name
    serp_list = []
    for rec in yt.read_table(serponly_filtered_table_name):
        serp = Serp(**rec)
        serp_list.append(serp)
    print "write output dir", out_dir
    uuid_folder = str(uuid.uuid4())
    SerpSetSimpleHtm(serp_settings, uuid_folder).create_out_dir(out_dir, serp_list, example_sets)
    SerpSet2Json(serp_settings, uuid_folder).create_out_json(out_dir, serp_list, example_sets)


def upload_local_dir_to_mr(input_dir, html_table):
    yt.create('table', html_table, recursive=True, ignore_existing=True)

    html_record_list = []
    for fname in os.listdir(input_dir):
        if ".htm" not in fname: continue
        print ">>", fname
        with open(os.path.join(input_dir, fname)) as f:
            html = f.read()
            html_record_list.append(dict(key=fname, subkey='', value=html))
    yt.write_table(html_table, html_record_list)


def process_local_dir(input_dir, out_dir):
    try:
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.mkdir(out_dir)
    except:
        pass

    # serp_settings = GoogleSerpSettingsTouch()
    serp_settings = GoogleSerpSettingsTouchShort()
    serp_settings.HTM_SERP_ELEMENTS_SKIP_IF_QUERIES_LE = -1
    serp_short_settings = GoogleSerpSettingsTouchShort()
    serp_parser = SerpParser(serp_settings, serp_short_settings)
    phantomjs_path = "phantomjs.exe" if os.name == 'nt' else "phantomjs"
    screenshotter = Screenshotter(phantomjs_path=phantomjs_path, verbose=True)

    serp_list = []
    serp_element_list = []
    for fname in os.listdir(input_dir):
        if ".htm" not in fname: continue
        print ">>", fname
        with open(os.path.join(input_dir, fname)) as f:
            html = f.read()

        serp = serp_parser.parse_serp(html)

        serp_dir = os.path.join(out_dir, fname.replace(".html", ""))
        try:
            os.mkdir(serp_dir)
        except:
            pass
        image_list = screenshotter.make_screenshots(serp.html_with_seanid, "[sean='1']", serp_dir)

        seanid2image = dict((image.seanid, image) for image in image_list)
        for serp_element in serp.serp_elem_list:
            serp_element.image = seanid2image.get(serp_element.seanid)
            serp_element_list.append(serp_element)

        serp_list.append(serp)

    example_sets = []
    serp_element_reducer = SerpElementExampleSetReducer(serp_settings)
    serp_element_list.sort(key=lambda serp_element: (serp_element.skeleton_md5, serp_element.seanid))
    for skeleton_md5, recs_list in itertools.groupby(serp_element_list, key=lambda serp_element: serp_element.skeleton_md5):
        recs_list = map(attr.asdict, recs_list)
        for rec in serp_element_reducer(skeleton_md5, recs_list):
            serp_element_example_set = SerpElementExampleSet(**rec)
            for i, serp_element in enumerate(serp_element_example_set.serp_elem_list):
                serp_element = SerpElement(**serp_element)
                if serp_element.image:
                    serp_element.image = Image(**serp_element.image)
                serp_element_example_set.serp_elem_list[i] = serp_element
            example_sets.append(serp_element_example_set)

    onto_htm = os.path.join(out_dir, "onto.htm")
    print "write file", onto_htm
    uuid_folder = str(uuid.uuid4())
    SerpSetSimpleHtm(serp_settings, uuid_folder).create_out_dir(out_dir, serp_list, example_sets)
    SerpSet2Json(serp_settings, uuid_folder).create_out_json(out_dir, serp_list, example_sets)


if __name__ == '__main__':
    # Command-line entry point. Two modes:
    #   --local:  parse --input_dir on this machine and write into -d;
    #   default:  run the YT pipeline from table -i to tables prefixed with -o,
    #             optionally uploading --input_dir into -i first.
    parser = argparse.ArgumentParser(description='Serp anatomy parser for MR\n' +
        'You need to have static phantomjs binary in your path https://github.com/macbre/phantomjs/releases/tag/2.0.0')
    parser.add_argument('--input_dir', metavar='input_dir', dest='input_dir',
                        required=False, help="input dir with *.htm tables")
    # -i and -o are only meaningful in the YT mode, so argparse must not
    # demand them for --local runs; they are validated per mode below.
    parser.add_argument('-i', metavar='input_table', dest='input_table',
                        required=False, help="MR table")
    parser.add_argument('-o', metavar='output_table', dest='output_table',
                        required=False, help="MR output table")
    parser.add_argument('-d', metavar='output_dir', dest='output_dir',
                        required=True, help="output dir with anatomy")
    parser.add_argument('--continue_from', choices=STAGES, help="continue from a specific stage")
    parser.add_argument('--local', action="store_true", help="run locally using --input_dir <input_dir> -d <output_dir>")
    parser.add_argument('--debug', action="store_true", help="some optimizations to run faster on small data")
    args = parser.parse_args()

    # Validate the options each mode actually needs. parser.error() prints the
    # usage message and exits with status 2, like a missing required option.
    if args.local:
        if not args.input_dir:
            parser.error("--local requires --input_dir")
    elif not (args.input_table and args.output_table):
        parser.error("-i and -o are required unless --local is given")

    DEBUG = args.debug

    if args.local:
        process_local_dir(args.input_dir, args.output_dir)

    else:
        if args.input_dir:
            upload_local_dir_to_mr(args.input_dir, args.input_table)

        process_mr([args.input_table], args.output_table, args.output_dir, args.continue_from)
