from __future__ import unicode_literals, print_function

import argparse
from contextlib import contextmanager
import xml.etree.ElementTree as etree
from xml.sax.saxutils import quoteattr

import os
import six
import sys

# Tags whose text patch_etree_cname() writes as a CDATA section.
CNAME_TAGS = ('system-out', 'skipped', 'error', 'failure')
# Template that wraps an element's text in a CDATA section.
CNAME_PATTERN = '<![CDATA[{}]]>'
# Template for a whole element: opening tag with attributes, text, closing tag.
TAG_PATTERN = '<{tag}{attrs}>{text}</{tag}>'
# Counters assigned by merge_trees() through ``global``; the __main__ block
# exits with status 1 when either one is greater than zero.
ERROR_COUNT = 0
FAILURES_COUNT = 0


@contextmanager
def patch_etree_cname(etree):
    """
    Temporarily patch ElementTree's XML serializer so that the text of
    the tags listed in CNAME_TAGS is written as a CDATA section.

    ``etree`` is the ``xml.etree.ElementTree`` module (or any object that
    exposes ``_serialize_xml`` and the ``_serialize`` dict). The original
    serializer is put back when the ``with`` block exits, even when the
    body raises.

    For the patched tags only the element's attributes and text are
    written; child elements and tail text are not.
    """
    original_serialize = etree._serialize_xml

    def _serialize_xml(write, elem, *args, **kwargs):
        if elem.tag not in CNAME_TAGS:
            original_serialize(write, elem, *args, **kwargs)
            return

        attrs = ' '.join(
            '{}={}'.format(k, quoteattr(v))
            for k, v in sorted(elem.attrib.items())
        )
        attrs = ' ' + attrs if attrs else ''

        # An element without text has ``text`` set to None; emit an empty
        # CDATA section instead of the literal string "None".
        text = elem.text or ''
        # "]]>" would terminate the CDATA section early, so split it
        # across two adjacent sections.
        text = text.replace(']]>', ']]]]><![CDATA[>')

        markup = TAG_PATTERN.format(
            tag=elem.tag,
            attrs=attrs,
            text=CNAME_PATTERN.format(text)
        )
        # Python 2's ElementTree writes encoded bytes, while Python 3
        # passes a text-mode write() that rejects bytes.
        if sys.version_info[0] == 2:
            markup = markup.encode('utf-8')
        write(markup)

    etree._serialize_xml = etree._serialize['xml'] = _serialize_xml
    try:
        yield
    finally:
        # Always undo the monkey patch, even if the body raised.
        etree._serialize_xml = etree._serialize['xml'] = original_serialize


def _same_testcase(orig_child, rerun_child):
    """Return True when both elements describe the same test case."""
    if orig_child.attrib.get('name') != rerun_child.attrib.get('name'):
        return False
    orig_class = orig_child.attrib.get('classname')
    rerun_class = rerun_child.attrib.get('classname')
    # Only let the classname veto a match when both sides carry one.
    return orig_class is None or rerun_class is None or orig_class == rerun_class


def _remaining_count(base_count, rerun_count):
    """Combine an errors/failures count of the base root with a rerun's."""
    if base_count == rerun_count:
        return base_count
    if rerun_count == 0:
        return 0
    # NOTE(review): the subtraction is kept from the original logic; it can
    # go negative when the rerun reports more problems than the base run.
    # Confirm this is the intended "remaining" count.
    return base_count - rerun_count


def merge_trees(*trees):
    """
    Merge all given XUnit ElementTrees into the first one and return it.

    The first tree is the base and is modified in place. For every later
    ("rerun") tree:

    * each child of its root that has a ``name`` attribute replaces the
      children of the base root with the same ``name`` (and the same
      ``classname`` when both elements have one). The replacement is
      appended at the end of the base root. Rerun children without a
      match in the base are ignored.
    * the digit-valued attributes of the base root are rewritten;
      ``errors`` and ``failures`` are combined with the rerun root's
      values, every other attribute is only normalised through ``int``.

    The last computed ``errors`` / ``failures`` values are also stored as
    integers in the module globals ERROR_COUNT / FAILURES_COUNT. These
    globals are only assigned when at least two trees are given.

    Raises IndexError when called without any tree.
    """
    global ERROR_COUNT, FAILURES_COUNT

    print('Merging xml files ito a single file')
    first_tree = trees[0]
    first_root = first_tree.getroot()

    for tree in trees[1:]:
        root = tree.getroot()
        # Element.getchildren() was removed in Python 3.9; iterating the
        # element works on every version.
        for rerun_child in list(root):
            if rerun_child.attrib.get('name') is None:
                # e.g. <properties> or <system-out>: nothing to match on.
                continue
            # Take a fresh snapshot for every rerun child, because earlier
            # replacements have already changed the base root's children.
            replaced = [
                orig_child for orig_child in list(first_root)
                if _same_testcase(orig_child, rerun_child)
            ]
            if not replaced:
                continue
            for orig_child in replaced:
                first_root.remove(orig_child)
            first_root.append(rerun_child)

        # combine root attributes which stores the number
        # of executed tests, skipped tests, etc
        for key, value in list(first_root.attrib.items()):
            if not value.isdigit():
                continue
            count = int(value)
            if key in ('errors', 'failures'):
                count = _remaining_count(
                    count, int(root.attrib.get(key, '0'))
                )
                # Keep the globals numeric so they can be compared to 0.
                if key == 'errors':
                    ERROR_COUNT = count
                else:
                    FAILURES_COUNT = count
            # str.format keeps the value a text string on Python 2 and 3.
            first_root.set(key, '{}'.format(count))

    return first_tree


def merge_xunit(files, output, callback=None):
    """
    Parse the given xunit xml files, merge them and write one xml file.

    ``files`` is an iterable of paths (or file objects) accepted by
    ``etree.parse``; ``output`` is where the merged document is written,
    encoded as utf-8 and with an XML declaration.

    When ``callback`` is given it receives the merged ElementTree before
    it is written. It may change the tree in place and return None, or
    return a different ElementTree, which is then written instead.
    """
    parsed = [etree.parse(source) for source in files]
    merged = merge_trees(*parsed)

    # Give the caller a chance to post-process or swap the merged tree.
    replacement = callback(merged) if callback is not None else None
    if replacement is not None:
        merged = replacement

    # The patch makes the tags in CNAME_TAGS come out as CDATA sections.
    with patch_etree_cname(etree):
        merged.write(output, encoding='utf-8', xml_declaration=True)


if __name__ == "__main__":
    # Collect every .xml file below the directory given with -o, merge
    # them into <dir>/merged.xml and exit with status 1 when the merged
    # report still counts errors or failures.
    parser = argparse.ArgumentParser(description='Merging xmls into a single xml.')
    parser.add_argument('-o', action='store', default='',
                        dest='output_dir',
                        help='dir for xmls')
    results = parser.parse_args()

    # Build the output path first so that a merged.xml left behind by a
    # previous run is not fed back in as an input.
    merged_xml = os.path.join(results.output_dir, 'merged.xml')

    # NOTE(review): os.walk order is filesystem dependent, and the first
    # file found becomes the base tree of the merge - confirm whether a
    # specific order is required.
    files_list = []
    for root, dirs, files in os.walk(results.output_dir):
        for filename in files:
            path = os.path.join(root, filename)
            if filename.endswith('.xml') and path != merged_xml:
                files_list.append(path)
    print("***************")
    print(files_list)
    print("***************")
    if files_list:
        merge_xunit(files=files_list, output=merged_xml)
    # The counters may hold an int or its text form; normalise them so
    # the comparison is numeric on both Python 2 and Python 3.
    if int(ERROR_COUNT) > 0 or int(FAILURES_COUNT) > 0:
        sys.exit(1)
