import StringIO
import cPickle
import os
import shutil
import subprocess
import tempfile
import traceback

import attr
import yt.wrapper as yt

from Screenshotter import Screenshotter, Image
from SerpParser import SerpParser, SerpElementExampleSet, SerpElement
from settings import SerpSettings


@yt.aggregator
class SerpParserCombiner:
    """Parse SERP HTML records and emit the results into several output tables.

    Each input record supplies the query in ``rec['key']`` and the page HTML
    in ``rec['value']``.  Depending on ``fields_to_out`` the parsed SERP, its
    elements, raw skeleton rows and a SERP-only row are written; every group
    of rows is preceded by a ``yt.create_table_switch`` to the matching
    ``OUT_TABLE_*`` index.
    """

    # Output table indices passed to ``yt.create_table_switch``.
    OUT_TABLE_PARSED = 0
    OUT_TABLE_SKEL = 1
    OUT_TABLE_SERPONLY = 2
    OUT_TABLE_RAW_SKEL = 3
    OUT_TABLE_ERRORS = 4

    # Flags selecting which kinds of output ``map`` produces.
    FIELD_SERP = 1
    FIELD_SCREENSHOT = 2
    FIELD_SKELETON = 3
    FIELD_RAW_SKELETON = 4
    FIELD_SERPONLY = 5
    ALL_FIELDS = (FIELD_SERP, FIELD_SCREENSHOT, FIELD_SKELETON, FIELD_RAW_SKELETON, FIELD_SERPONLY)

    def __init__(self, serp_settings, short_serp_settings, raise_exceptions=True, verbose=False,
                 write_skeleton_table=False, fields_to_out=ALL_FIELDS):
        """Set up the parser and the screenshotter.

        :param serp_settings: ``SerpSettings`` instance; provides ``cleanup_html``
            and ``RESOLUTION``.
        :param short_serp_settings: passed through to ``SerpParser``.
        :param raise_exceptions: when true, errors in ``map`` propagate; otherwise
            they are written as rows to the errors table.
        :param verbose: forwarded to ``Screenshotter``.
        :param write_skeleton_table: stored on the instance; not read anywhere
            in this class.
        :param fields_to_out: collection of ``FIELD_*`` flags to produce.
        """
        assert isinstance(serp_settings, SerpSettings)
        self.serp_settings = serp_settings
        self.serp_parser = SerpParser(serp_settings, short_serp_settings)
        self.screenshotter = Screenshotter(
            phantomjs_path="phantomjs_bundle/phantomjs",
            env={'FONTCONFIG_PATH': 'phantomjs_bundle'},
            verbose=verbose)
        self.screenshotter.resolution = serp_settings.RESOLUTION
        self.raise_exceptions = raise_exceptions
        self.write_skeleton_table = write_skeleton_table
        self.fields_to_out = fields_to_out

    def __call__(self, rec_list):
        """Process every record of ``rec_list`` and yield the rows from ``map``.

        When screenshots are requested and ``phantomjs_bundle.tgz`` exists in
        the working directory, the archive is unpacked once before processing.
        """
        if self.FIELD_SCREENSHOT in self.fields_to_out and os.path.exists('phantomjs_bundle.tgz'):
            subprocess.check_call(["tar", "xzf", "phantomjs_bundle.tgz"])
        for rec in rec_list:
            for outrec in self.map(rec):
                yield outrec

    def _attach_screenshots(self, serp):
        """Make screenshots for ``serp`` and attach them to its elements by ``seanid``."""
        serp_dir = tempfile.mkdtemp(prefix='SerpParserMapper_', dir='.')
        try:
            image_list = self.screenshotter.make_screenshots(serp.html_with_seanid, "[sean='1']", serp_dir)
        finally:
            # Remove the scratch directory even if the screenshotter fails,
            # so failed records do not accumulate directories on disk.
            shutil.rmtree(serp_dir)

        seanid2image = {image.seanid: image for image in image_list}
        for serp_element in serp.serp_elem_list:
            # Elements without a matching screenshot get ``None``.
            serp_element.image = seanid2image.get(serp_element.seanid)

    def map(self, rec):
        """Parse one record and yield table switches followed by output rows.

        :param rec: mapping with ``'key'`` (query) and ``'value'`` (HTML); its
            ``'subkey'``, if present, is copied into error rows.
        :raises Exception: whatever parsing/screenshotting raised, but only
            when ``raise_exceptions`` is true; otherwise one row describing
            the failure is yielded to the errors table.
        """
        query = rec['key']
        html = rec['value']
        try:
            html = self.serp_settings.cleanup_html(html)

            serp = self.serp_parser.parse_serp(html)
            serp.query_text = query

            if self.FIELD_SCREENSHOT in self.fields_to_out:
                self._attach_screenshots(serp)

            if self.FIELD_SERP in self.fields_to_out:
                yield yt.create_table_switch(self.OUT_TABLE_PARSED)
                yield attr.asdict(serp)

            if self.FIELD_SKELETON in self.fields_to_out:
                for serp_element in serp.serp_elem_list:
                    yield yt.create_table_switch(self.OUT_TABLE_SKEL)
                    yield attr.asdict(serp_element)

            if self.FIELD_RAW_SKELETON in self.fields_to_out:
                for serp_element in serp.serp_elem_list:
                    skeleton_md5 = SerpParser.eval_skeleton_md5(serp_element.skeleton)
                    yield yt.create_table_switch(self.OUT_TABLE_RAW_SKEL)
                    yield dict(key=skeleton_md5, subkey=serp_element.seanid, value=query)

            if self.FIELD_SERPONLY in self.fields_to_out:
                # Must stay last: the element list is dropped from the SERP
                # before it is serialized for the SERP-only table.
                serp.serp_elem_list = None
                yield yt.create_table_switch(self.OUT_TABLE_SERPONLY)
                yield attr.asdict(serp)

        except Exception:
            if self.raise_exceptions:
                raise
            yield yt.create_table_switch(self.OUT_TABLE_ERRORS)
            # ``rec.get`` keeps a record without 'subkey' from raising a
            # KeyError here, which would hide the original failure.
            yield dict(query=query, subkey=rec.get('subkey'),
                       error="ERROR:" + traceback.format_exc(), html=html)


class SerpElementExampleSetReducer:
    """Fold the SERP element rows of one key into a single ``SerpElementExampleSet`` row."""

    def __init__(self, serp_settings):
        """Store ``serp_settings``, which must be a ``SerpSettings`` instance."""
        assert isinstance(serp_settings, SerpSettings)
        self.serp_settings = serp_settings

    def __call__(self, key, rec_list):
        """Aggregate ``rec_list`` and yield one dict built with ``attr.asdict``.

        ``key`` is not used.  Every record is turned into a ``SerpElement``;
        the result keeps a bounded sample of elements, the set of short
        skeletons, the element count, the number of distinct query texts and
        the number of images that have content and meet the minimum size.
        """
        settings = self.serp_settings
        example_set = SerpElementExampleSet()
        seen_queries = set()

        for row in rec_list:
            element = SerpElement(**row)
            seen_queries.add(element.query_text)

            # Keep at most HTM_SERP_ELEMENTS_SAMPLE_COUNT example elements.
            sample = example_set.serp_elem_list
            if len(sample) < settings.HTM_SERP_ELEMENTS_SAMPLE_COUNT:
                sample.append(element)

            # Skeleton fields are filled in until a truthy skeleton is stored.
            if not example_set.skeleton:
                example_set.skeleton = element.skeleton
                example_set.short_skeleton = element.short_skeleton
            example_set.short_skeleton_set.add(element.short_skeleton)
            example_set.serp_elem_count += 1

            if not element.image:
                continue

            # Replace the raw image mapping by an ``Image`` object in place;
            # a sampled element is the same object and sees the change too.
            element.image = Image(**element.image)
            image = element.image
            if not image.binary_content:
                continue
            if image.height >= settings.HTM_SERP_ELEMENT_IMAGE_DIMENSION_MIN \
                    and image.width >= settings.HTM_SERP_ELEMENT_IMAGE_DIMENSION_MIN:
                example_set.count_good_images += 1

        example_set.query_count = len(seen_queries)

        yield attr.asdict(example_set)


class SerpSetFilterMapper:
    """Mapper that passes through only records whose ``'id'`` is in a given collection."""

    def __init__(self, serp_id_set):
        """Remember ``serp_id_set``, the collection of ids to keep."""
        self.serp_id_set = serp_id_set

    def __call__(self, rec):
        """Yield ``rec`` unchanged if ``rec['id']`` is in ``serp_id_set``; otherwise yield nothing."""
        allowed_ids = self.serp_id_set
        if rec['id'] not in allowed_ids:
            return
        yield rec
