from sandbox import sdk2
from sandbox import common
from sandbox.projects.common.arcadia import sdk
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.logs.common import GetRevision
from sandbox.sdk2.helpers import subprocess, ProcessLog

import logging
import re
import os.path


def remove_comments(data):
    """Return ``data`` with every ``<!-- ... -->`` span cut out.

    Both markers are searched for starting at the current scan position.
    An ``AssertionError`` is raised when a ``-->`` is found while no
    ``<!--`` remains, or when the next ``-->`` does not come after the next
    ``<!--`` (this also covers an opening marker with no closing one).
    """
    opening, closing = '<!--', '-->'
    kept_parts = []
    position = 0
    total = len(data)
    while position < total:
        start = data.find(opening, position)
        stop = data.find(closing, position)
        if start == -1:
            # No more comments: the remaining tail is kept as is.
            assert stop == -1, 'Bad comment: no end (-->)'
            kept_parts.append(data[position:])
            break

        assert start < stop, 'Bad comment: end (-->) before begin (<!--): {} {}'.format(start, stop)
        kept_parts.append(data[position:start])
        # Resume scanning right after the closing marker.
        position = stop + len(closing)
    return ''.join(kept_parts)


def read_file(filename):
    """Read ``filename``, drop XML comments and return its lines stripped of surrounding whitespace.

    The content is split on ``'\\n'`` only, after comments have been removed
    by ``remove_comments``.
    """
    with open(filename, 'r') as source:
        content = remove_comments(source.read())

    stripped_lines = []
    for raw_line in content.split('\n'):
        stripped_lines.append(raw_line.strip())
    return stripped_lines


# Opening tag of a named complex type, with either the xs: or the xsd: prefix.
COMPLEX_TYPE_BEGIN_REGEXP = re.compile('<xs[d]?:complexType name=.*')


def is_complex_type_begin(line):
    """Return True when ``line`` begins with a named complexType opening tag."""
    opening_tag = COMPLEX_TYPE_BEGIN_REGEXP.match(line)
    return opening_tag is not None


# Opening tag of a named group definition, with either the xs: or the xsd: prefix.
GROUP_TYPE_BEGIN_REGEXP = re.compile('<xs[d]?:group name=.*')


def is_group_type_begin(line):
    """Return True when ``line`` begins with a named group opening tag."""
    opening_tag = GROUP_TYPE_BEGIN_REGEXP.match(line)
    return opening_tag is not None


# Captures the value of every name="..." attribute found in a line.
COMPLEX_TYPE_NAME_REGEXP = re.compile('name="(.*?)"')


def get_name_from_line(line):
    """Return the value of the single ``name="..."`` attribute in ``line``.

    Raises ``AssertionError`` unless exactly one such attribute is present.
    """
    found_names = COMPLEX_TYPE_NAME_REGEXP.findall(line)
    assert len(found_names) == 1, "Bad complexType name: {}, names: {}".format(line, found_names)
    return found_names[0]


# Closing tag of a complex type, with either the xs: or the xsd: prefix.
COMPLEX_TYPE_END_REGEXP = re.compile('</xs[d]?:complexType>')


def is_complex_type_end(line):
    """Return True when ``line`` begins with a complexType closing tag."""
    closing_tag = COMPLEX_TYPE_END_REGEXP.match(line)
    return closing_tag is not None


# Closing tag of a group, with either the xs: or the xsd: prefix.
GROUP_TYPE_END_REGEXP = re.compile('</xs[d]?:group>')


def is_group_type_end(line):
    """Return True when ``line`` begins with a group closing tag."""
    closing_tag = GROUP_TYPE_END_REGEXP.match(line)
    return closing_tag is not None


# Captures the base="..." value of an extension tag.
EVENT_BASE_REGEXP = re.compile('<xs[d]?:extension base="(.*?)">')


def is_event_base(line):
    """Return True when ``line`` declares an extension whose base is ``event``.

    Lines without an extension, and extensions of ``server_request_with_id``,
    yield False.  Any other base, or more than one extension on the line,
    raises ``AssertionError``.
    """
    bases = EVENT_BASE_REGEXP.findall(line)
    assert len(bases) <= 1, 'Multiple base of event: {}'.format(line)
    if len(bases) != 1:
        return False

    base_name = bases[0]
    if base_name == 'server_request_with_id':
        return False

    assert base_name == 'event', 'Base of complexType must be event {}'.format(line)
    return True


def is_group(diff, begin, end):
    """Return True when the definition starting at ``diff[begin]`` is a group.

    Only the first line of the definition is inspected; ``end`` is accepted
    for signature symmetry with ``is_event`` and is not used.
    """
    first_line = diff[begin]
    return is_group_type_begin(first_line)


def is_event(diff, begin, end):
    """Return True when lines ``[begin, end)`` of ``diff`` describe an event type.

    A group is never an event.  Otherwise the definition is an event when one
    of its lines is recognised by ``is_event_base``; a second such line raises
    ``AssertionError``.
    """
    if is_group(diff, begin, end):
        return False

    found_event_base = False
    for position in range(begin, end):
        current_line = diff[position]
        if not is_event_base(current_line):
            continue
        assert not found_event_base, 'Event with base dublication: {}'.format(current_line)
        found_event_base = True
    return found_event_base


# Captures the value of the sc:stable="..." attribute.
STABLE_ATTRIBUTE_REGEXP = re.compile('sc:stable="(.*?)"')


def get_stable_attribute_from_line(line):
    """Return the ``sc:stable`` attribute of ``line`` as ``'true'``/``'false'``, or None if absent.

    Raises ``AssertionError`` when the attribute occurs more than once or has
    any value other than ``true``/``false``.
    """
    found = STABLE_ATTRIBUTE_REGEXP.findall(line)
    assert len(found) <= 1, 'Multiple stable attributes: {}'.format(line)
    if not found:
        return None

    value = found[0]
    assert value in ('true', 'false'), 'Stable attribute must be true or false: {}'.format(line)
    return value


# Captures the referenced name of a <group ref="..."> usage (self-closing or not).
GROUP_IN_EVENT_REGEXP = re.compile('<xs[d]?:group ref="(.*?)"[/]?>')


def get_group_in_event_from_line(line):
    """Return the group name referenced by ``line`` via ``<xs:group ref="...">``, or None.

    Raises ``AssertionError`` when the line contains more than one reference.
    """
    references = GROUP_IN_EVENT_REGEXP.findall(line)
    assert len(references) <= 1, 'Multiple group in event in one line: {}'.format(line)
    return references[0] if references else None


def find_next_complex_type_or_group(diff, start_idx):
    """Locate the first complete complexType or group definition at or after ``start_idx``.

    ``diff`` is a list of lines.  The scan stops on the line where both an
    opening and a closing tag have been seen and returns a dict with:

    * ``begin`` / ``end`` - indexes of the opening and closing lines;
    * ``name`` - value of the ``name`` attribute on the opening line;
    * ``stable`` - ``sc:stable`` value of a complexType opening line
      (``'true'``, ``'false'`` or None); always None for groups;
    * ``groups`` - every ``<group ref="...">`` name seen on the scanned lines.

    Returns None when no complete definition is found before the end of
    ``diff``.  Raises ``AssertionError`` on a second opening tag before the
    closing one, on a repeated closing tag, or on a complexType closing tag
    when the current definition is not a complexType.
    """
    type_begin_idx = None
    type_end_idx = None
    name = None
    stable = None
    current_type = None
    groups = []
    for idx in range(start_idx, len(diff)):
        line = diff[idx]
        if is_complex_type_begin(line):
            assert type_begin_idx is None, 'Multiple event begin'
            current_type = 'complex_type'
            name = get_name_from_line(line)
            stable = get_stable_attribute_from_line(line)
            type_begin_idx = idx
        elif is_complex_type_end(line):
            assert type_end_idx is None, 'Bad line: {} {} {}'.format(idx, type_begin_idx, type_end_idx)
            assert current_type == 'complex_type'
            type_end_idx = idx
        elif is_group_type_begin(line):
            assert type_begin_idx is None, 'Bad line: {} {} {}'.format(idx, type_begin_idx, type_end_idx)
            current_type = 'group'
            name = get_name_from_line(line)
            type_begin_idx = idx
        # A group closing tag only terminates the definition when a group
        # definition is the one currently open; otherwise it is ignored.
        elif is_group_type_end(line) and current_type == 'group':
            assert type_end_idx is None, 'Bad line: {} {} {}'.format(idx, type_begin_idx, type_end_idx)
            type_end_idx = idx

        # Group references are collected from every scanned line, including
        # any lines between start_idx and the opening tag.
        group_from_line = get_group_in_event_from_line(line)
        if group_from_line:
            groups.append(group_from_line)

        if type_begin_idx is not None and type_end_idx is not None:
            return {'begin': type_begin_idx,
                    'end': type_end_idx,
                    'name': name,
                    'stable': stable,
                    'groups': groups}

    return None


def find_all_complex_types_and_groups(diff):
    """Collect every complexType and group definition in ``diff``, keyed by name.

    Definitions are discovered one after another with
    ``find_next_complex_type_or_group``; each search resumes on the line after
    the previous definition's closing tag.  Two definitions with the same
    name raise ``AssertionError``.
    """
    found = dict()
    cursor = 0
    while cursor < len(diff):
        definition = find_next_complex_type_or_group(diff, cursor)
        if definition is None:
            break
        type_name = definition['name']
        assert type_name not in found, 'Events with same name: {}'.format(definition)
        found[type_name] = definition
        cursor = definition['end'] + 1
    return found


# Element declaration with the name attribute first: captures (name, type).
COMPLEX_TYPE_FIELD_DECLARATION_NAME_TYPE_REGEXP = re.compile('.*<xs[d]?:element name="(.*?)".*type="(.*?)"')
# Element declaration with the type attribute first: captures (type, name).
COMPLEX_TYPE_FIELD_DECLARATION_TYPE_NAME_REGEXP = re.compile('.*<xs[d]?:element type="(.*?)".*name="(.*?)"')
IS_NILLABLE_FIELD_REGEXP = re.compile('.*<xs[d]?:element.*nillable="true"')
FIELD_VERSION_REGEXP = re.compile('.*<xs[d]?:appinfo source="scarab:version">([0-9]*)')


def get_complex_type_fields(diff, begin, end):
    """Extract the element declarations found on lines ``[begin, end)`` of ``diff``.

    Returns a list of dicts, one per declared element, with the keys
    ``name``, ``type``, ``is_nillable`` and ``version``.  The version is read
    from a ``scarab:version`` appinfo tag on one of the two lines following
    the declaration and defaults to 1.

    Both attribute orders (``name=... type=...`` and ``type=... name=...``)
    are recognised, and the captured groups are mapped to ``name``/``type``
    according to the order used by the matching pattern.

    Raises ``AssertionError`` when a line yields more than one match for a
    declaration pattern or for the version pattern.
    """
    res = []

    def add_fields_to_res(fields, name_group, type_group, is_nillable, version):
        # ``fields`` is the findall() result for one declaration pattern;
        # ``name_group``/``type_group`` say which captured group holds what.
        assert len(fields) <= 1, 'Bad event fields: {}'.format(fields)
        if len(fields) == 1:
            assert len(fields[0]) == 2, 'Bad event fields: {}'.format(fields)
            res.append({'name': fields[0][name_group],
                        'type': fields[0][type_group],
                        'is_nillable': is_nillable,
                        'version': version})

    def find_field_version(diff, begin, end):
        # ``begin`` is the line right after the declaration being processed.
        if begin == end:
            return 1
        line = diff[begin]
        versions = FIELD_VERSION_REGEXP.findall(line)
        assert len(versions) <= 1, 'Bad field version: {}'.format(line)
        if len(versions) == 1 and versions[0] != '':
            return int(versions[0])

        if begin + 1 == end:
            return 1
        # If the next line already declares another field, a version tag on
        # the line after it is not taken for the current field.
        is_new_field = bool(COMPLEX_TYPE_FIELD_DECLARATION_NAME_TYPE_REGEXP.match(line)) or \
                       bool(COMPLEX_TYPE_FIELD_DECLARATION_TYPE_NAME_REGEXP.match(line))
        line = diff[begin + 1]
        versions = FIELD_VERSION_REGEXP.findall(line)
        assert len(versions) <= 1, 'Bad field version: {}'.format(line)
        if not is_new_field and len(versions) == 1 and versions[0] != '':
            return int(versions[0])
        return 1

    for idx in range(begin, end):
        line = diff[idx]
        is_nillable = bool(IS_NILLABLE_FIELD_REGEXP.match(line))
        version = find_field_version(diff, idx + 1, end)
        # name-first pattern: group 0 is the name, group 1 is the type.
        add_fields_to_res(COMPLEX_TYPE_FIELD_DECLARATION_NAME_TYPE_REGEXP.findall(line), 0, 1, is_nillable, version)
        # type-first pattern: group 0 is the type, group 1 is the name.
        add_fields_to_res(COMPLEX_TYPE_FIELD_DECLARATION_TYPE_NAME_REGEXP.findall(line), 1, 0, is_nillable, version)

    return res


# Prefix reserved for scarab's own attributes.
SCARAB_KEY_WORD = re.compile('scarab:')
# A field name: starts with a letter, then letters, digits or underscores.
NAME_FIELD_REGEXP = re.compile('([a-zA-Z]){1,}[a-zA-Z0-9_]*$')
# Names accepted even when they do not satisfy NAME_FIELD_REGEXP.
NAME_RULE_VIOLATIONS = ['percentile_90', ]


def validate_name_field(name):
    """Assert that ``name`` is an acceptable field name.

    The name must not start with ``scarab:`` and must either match
    ``NAME_FIELD_REGEXP`` or be listed in ``NAME_RULE_VIOLATIONS``;
    otherwise ``AssertionError`` is raised.
    """
    assert not SCARAB_KEY_WORD.match(name), "Name field must differ from scarab (scarab:) key word: {}".format(name)
    if NAME_FIELD_REGEXP.match(name):
        return
    assert name in NAME_RULE_VIOLATIONS, "Name field contains bad symbols: {}".format(name)


# A type name: letters, each optionally followed by '_' and/or ':', ending in a letter or digit.
TYPE_FIELD_REGEXP = re.compile('([a-zA-Z]_?:?)*[a-zA-Z0-9]$')


def validate_type_field(tp):
    """Assert that ``tp`` is an acceptable field type.

    The type must match ``TYPE_FIELD_REGEXP`` and must not start with
    ``scarab:``; otherwise ``AssertionError`` is raised.
    """
    follows_type_rule = TYPE_FIELD_REGEXP.match(tp)
    assert follows_type_rule, "Type field hasn't match: {}".format(tp)
    starts_with_keyword = SCARAB_KEY_WORD.match(tp)
    assert not starts_with_keyword, "Type field must differ from scarab key word: {}".format(tp)


def validate_name_type_fileds(fields):
    """Validate the ``name`` and then the ``type`` of every field dict in ``fields``."""
    for entry in fields:
        field_name = entry['name']
        validate_name_field(field_name)
        field_type = entry['type']
        validate_type_field(field_type)


def extract_and_validate_name_type(name, fields):
    """Validate ``fields`` and index them by field name.

    Returns a dict mapping each field name to the tuple
    ``(type, is_nillable, version)``.  ``name`` is the owning type's name and
    is used only in the duplicate-field error message.  Raises
    ``AssertionError`` on an invalid type or name, or on a repeated field name.
    """
    indexed = dict()
    for entry in fields:
        field_type = entry['type']
        validate_type_field(field_type)
        field_name = entry['name']
        validate_name_field(field_name)
        assert field_name not in indexed, 'Field {} dublication in type {}'.format(field_name, name)
        indexed[field_name] = (field_type, entry['is_nillable'], entry['version'],)
    return indexed


# Per-file lists of type names that are exempt from the checks below.
VERSION_RULE_VIOLATION = {'market.xsd': ['click_order_event_properties', 'cpa_click_event_properties', 'order_event_properties', ],
                          'pdb_informers.xsd': ['block', 'channel', ], }


def validate_complex_type_and_group_fields(xsd_filename, name, type_is_group, stable_groups, stable_fields, test_fields):
    """Check field changes of a complexType or group that has no ``sc:stable`` attribute.

    * Types listed in ``VERSION_RULE_VIOLATION`` for this file are skipped.
    * A group whose name is in ``stable_groups`` must keep exactly the same
      fields; other groups are not compared.
    * A complex type must keep every stable field unchanged, and every added
      field must be nillable.
    * Unless the type was skipped, every test field must have version 1.

    Raises ``AssertionError`` when a rule is broken.
    """
    for exempt_file, exempt_names in VERSION_RULE_VIOLATION.iteritems():
        if name in exempt_names and xsd_filename.endswith(exempt_file):
            return

    if not type_is_group:
        before = extract_and_validate_name_type(name, stable_fields)
        after = extract_and_validate_name_type(name, test_fields)
        # Subset check: every (name, info) pair that existed must still exist.
        assert before.viewitems() <= after.viewitems(), 'Deleted fields in complex type: {}'.format(name)
        for field_name, info in after.iteritems():
            if field_name in before:
                continue
            # info is (type, is_nillable, version)
            assert info[1], 'Added not nillable field ({}), in complex type: {}'.format(field_name, name)
    elif name in stable_groups:
        assert extract_and_validate_name_type(name, stable_fields) == extract_and_validate_name_type(name, test_fields), \
            'Group changed: {}'.format(name)

    for entry in test_fields:
        assert entry['version'] == 1, 'Version field in complex type or group: {}'.format(name)


def validate_event_type_fields(name, stable_fields, test_fields):
    """Check field changes of a stable event.

    Every stable field dict must still be present, unchanged, in
    ``test_fields``.  Every added field must carry a version greater than the
    highest stable version, and the lowest added version must be exactly one
    above it.  ``name`` is not used.  Raises ``AssertionError`` when a rule is
    broken.
    """
    highest_stable_version = 0
    for old_field in stable_fields:
        assert old_field in test_fields, "Field is deleted: {}".format(old_field['name'])
        highest_stable_version = max(highest_stable_version, old_field['version'])

    added_versions = []
    for candidate in test_fields:
        if candidate in stable_fields:
            continue
        assert candidate['version'] > highest_stable_version, 'New field added without incrementing version: {}'.format(candidate)
        added_versions.append(candidate['version'])

    if added_versions:
        lowest_added_version = min(added_versions)
        assert highest_stable_version + 1 == lowest_added_version, \
            'Version missed, was: {}, became: {}'.format(highest_stable_version, lowest_added_version)


def validate_changes_in_complex_types_and_groups(xsd_filename,
                                                 stable_file, stable_complex_types_and_groups, stable_groups,
                                                 test_file, test_complex_types_and_groups):
    """Compare every complexType/group of the stable file with its counterpart in the test file.

    Definitions marked ``sc:stable="false"`` in the stable file are skipped.
    Every other definition must still exist.  A definition without the
    attribute must stay without it and is checked by
    ``validate_complex_type_and_group_fields``; a definition marked
    ``true`` must stay ``true``, is checked by ``validate_event_type_fields``
    and must reference the same groups.  Raises ``AssertionError`` on any
    violation.
    """
    for name, old_type in stable_complex_types_and_groups.iteritems():
        old_stable_flag = old_type['stable']
        if old_stable_flag == 'false':
            continue

        assert name in test_complex_types_and_groups, "Complex type is deleted: {}".format(name)

        new_type = test_complex_types_and_groups[name]
        old_fields = get_complex_type_fields(stable_file, old_type['begin'], old_type['end'])
        new_fields = get_complex_type_fields(test_file, new_type['begin'], new_type['end'])

        if old_stable_flag is not None:
            assert new_type['stable'] == 'true', 'Stable attribute changed from true to false: {}'.format(old_type)
            validate_event_type_fields(name, old_fields, new_fields)
            assert old_type['groups'] == new_type['groups'], "Can't change groups in stable event: {}".format(name)
            continue

        assert new_type['stable'] is None, "Can't add stable attribute: {}".format(name)
        defined_as_group = is_group(stable_file, old_type['begin'], old_type['end'])
        validate_complex_type_and_group_fields(xsd_filename, name, defined_as_group, stable_groups, old_fields, new_fields)


def is_name_end_rule_violation(xsd_filename, type_name):
    """Return True when ``type_name`` in ``xsd_filename`` is a known exception to the event naming rule.

    Two types are exempt: ``TestSafeClickEvent`` in a file ending with
    ``layout.xsd`` and ``search_app_settings_event_v450`` in a file ending
    with ``mobile.xsd``.  The two conditions are mutually exclusive, so they
    are combined with ``or``; combining them with ``and`` could never be true.
    """
    is_test_safe_click = xsd_filename.endswith('layout.xsd') and type_name == 'TestSafeClickEvent'
    is_search_app_settings_event_v450 = xsd_filename.endswith('mobile.xsd') and type_name == 'search_app_settings_event_v450'
    return is_test_safe_click or is_search_app_settings_event_v450


def validate_complex_types_and_groups(xsd_filename,
                                      stable_file, stable_complex_types_and_groups,
                                      test_file, test_complex_types_and_groups):
    """Validate every complexType/group of the test file on its own and against its stable kind.

    For each definition in the test file:

    * if it also exists in the stable file, it must still be (or still not
      be) an event and a group;
    * if it is an event and not a known naming exception (see
      ``is_name_end_rule_violation``), its name must end with ``_event`` or
      ``_error``;
    * its field names and types are validated and must be unique.

    Raises ``AssertionError`` on any violation.
    """
    for name, test_type in test_complex_types_and_groups.iteritems():
        if name in stable_complex_types_and_groups:
            stable_type = stable_complex_types_and_groups[name]
            assert is_event(test_file, test_type['begin'], test_type['end']) == \
                   is_event(stable_file, stable_type['begin'], stable_type['end']), \
                   "Events attritube changed: file {}, {}".format(xsd_filename, name)
            assert is_group(test_file, test_type['begin'], test_type['end']) == \
                   is_group(stable_file, stable_type['begin'], stable_type['end']), \
                   "Group attribute changed: file {}, {}".format(xsd_filename, name)

        # The naming rule applies to every event except the known exceptions.
        if not is_name_end_rule_violation(xsd_filename, test_type['name']):
            if is_event(test_file, test_type['begin'], test_type['end']):
                assert name.endswith('_event') or name.endswith('_error'), 'Event name must ends with "event" or "error" word {}'.format(name)

        test_fields = get_complex_type_fields(test_file, test_type['begin'], test_type['end'])
        extract_and_validate_name_type(name, test_fields)


def get_groups_from_stable_events(complex_types_and_groups):
    """Return the set of group names referenced by definitions marked ``sc:stable="true"``."""
    referenced_groups = set()
    for _, type_info in complex_types_and_groups.iteritems():
        if type_info['stable'] != 'true':
            continue
        referenced_groups.update(type_info['groups'])
    return referenced_groups


# Opening tag of a named simple type, one pattern per namespace prefix.
SIMPLE_TYPE_BEGIN_REGEXP_1 = re.compile('<xsd:simpleType name=.*')
SIMPLE_TYPE_BEGIN_REGEXP_2 = re.compile('<xs:simpleType name=.*')


def is_simple_type_begin(line):
    """Return True when ``line`` begins with a named simpleType opening tag (``xsd:`` or ``xs:``)."""
    if SIMPLE_TYPE_BEGIN_REGEXP_1.match(line):
        return True
    return SIMPLE_TYPE_BEGIN_REGEXP_2.match(line) is not None


# Closing tag of a simple type, one pattern per namespace prefix.
SIMPLE_TYPE_END_REGEXP_1 = re.compile('</xsd:simpleType>')
SIMPLE_TYPE_END_REGEXP_2 = re.compile('</xs:simpleType>')


def is_simple_type_end(line):
    """Return True when ``line`` begins with a simpleType closing tag (``xsd:`` or ``xs:``)."""
    if SIMPLE_TYPE_END_REGEXP_1.match(line):
        return True
    return SIMPLE_TYPE_END_REGEXP_2.match(line) is not None


def find_next_simple_type(diff, start_idx):
    """Locate the first complete simpleType definition at or after ``start_idx``.

    Returns a dict with the ``begin`` and ``end`` line indexes and the
    ``name`` taken from the opening line, or None when no complete
    definition is found.  A repeated opening or closing tag before the
    definition is complete raises ``AssertionError``.
    """
    opened_at = None
    closed_at = None
    type_name = None
    for position in range(start_idx, len(diff)):
        current_line = diff[position]
        if is_simple_type_begin(current_line):
            assert opened_at is None, 'Bad line: {} {} {}'.format(position, opened_at, closed_at)
            type_name = get_name_from_line(current_line)
            opened_at = position
        elif is_simple_type_end(current_line):
            assert closed_at is None, 'Bad line: {} {} {}'.format(position, opened_at, closed_at)
            closed_at = position

        if opened_at is None or closed_at is None:
            continue
        return {'begin': opened_at,
                'end': closed_at,
                'name': type_name}

    return None


def find_all_simple_types(diff):
    """Collect every simpleType definition in ``diff``, keyed by name.

    Definitions are discovered one after another with
    ``find_next_simple_type``; each search resumes on the line after the
    previous definition's closing tag.  Two definitions with the same name
    raise ``AssertionError``.
    """
    found = dict()
    cursor = 0
    while cursor < len(diff):
        definition = find_next_simple_type(diff, cursor)
        if definition is None:
            break
        type_name = definition['name']
        assert type_name not in found, 'Events with same name: {}'.format(definition)
        found[type_name] = definition
        cursor = definition['end'] + 1
    return found


# Captures the base="..." value of a restriction tag.
STRING_RESTRICTION_REGEXP = re.compile('<xs[d]?:restriction base="(.*?)">')


def has_restriction(line):
    """Return True when ``line`` contains a restriction tag.

    The restriction base must be a string or unsignedInt type (with the
    ``xs:`` or ``xsd:`` prefix); any other base, or more than one restriction
    on the line, raises ``AssertionError``.
    """
    allowed_bases = ['xs:string', 'xsd:string', 'xs:unsignedInt', 'xsd:unsignedInt']
    bases = STRING_RESTRICTION_REGEXP.findall(line)
    assert len(bases) <= 1, 'Multiple base of event: {}'.format(line)
    if len(bases) != 1:
        return False

    assert bases[0] in allowed_bases, \
        'Base of enum must be string or unsignedInt: {}'.format(line)
    return True


def is_enum(diff, begin, end):
    """Return True when lines ``[begin, end)`` of ``diff`` contain a restriction.

    More than one line with a restriction raises ``AssertionError``.
    """
    found_restriction = False
    for position in range(begin, end):
        current_line = diff[position]
        if not has_restriction(current_line):
            continue
        assert not found_restriction, 'Enum with multiple restrictions: {}'.format(current_line)
        found_restriction = True
    return found_restriction


# Each pattern captures the value="..." of one kind of restriction facet.
ENUM_VALUE_REGEXP = re.compile('xs[d]?:enumeration value="(.*?)"')
MIN_INCLUSIVE_REGEXP = re.compile('xs[d]?:minInclusive value="(.*?)"')
MAX_INCLUSIVE_REGEXP = re.compile('xs[d]?:maxInclusive value="(.*?)"')


def get_enum_values(diff, begin, end):
    """Return the set of facet values declared on lines ``[begin, end)`` of ``diff``.

    Values of ``enumeration``, ``minInclusive`` and ``maxInclusive`` facets
    all go into the same set.  Raises ``AssertionError`` when one line holds
    several facets of the same kind or when a value is seen twice.
    """
    facet_patterns = (ENUM_VALUE_REGEXP, MIN_INCLUSIVE_REGEXP, MAX_INCLUSIVE_REGEXP)
    collected = set()
    for position in range(begin, end):
        current_line = diff[position]
        for pattern in facet_patterns:
            found = pattern.findall(current_line)
            assert len(found) <= 1, 'Multiple enum value: {}'.format(current_line)
            if len(found) != 1:
                continue
            facet_value = found[0]
            assert facet_value not in collected, 'Enum value dublication: {}'.format(facet_value)
            collected.add(facet_value)
    return collected


def validate_changes_in_simple_types(xsd_filename,
                                     stable_file, stable_simple_types,
                                     test_file, test_simple_types):
    """Compare every simpleType of the stable file with its counterpart in the test file.

    Each stable simple type must still exist and must keep its "is an enum"
    property.  For enums, every value present in the stable file must still
    be present in the test file.  ``xsd_filename`` is not used.  Raises
    ``AssertionError`` on any violation.
    """
    for name, old_type in stable_simple_types.iteritems():
        assert name in test_simple_types, 'Simple type deleted: {}'.format(name)
        new_type = test_simple_types[name]
        old_begin, old_end = old_type['begin'], old_type['end']
        new_begin, new_end = new_type['begin'], new_type['end']

        assert is_enum(stable_file, old_begin, old_end) == \
               is_enum(test_file, new_begin, new_end), \
               'Enum attribute changed: {}'.format(name)

        if not is_enum(stable_file, old_begin, old_end):
            continue

        stable_enum_values = get_enum_values(stable_file, old_begin, old_end)
        test_enum_values = get_enum_values(test_file, new_begin, new_end)
        for value in stable_enum_values:
            assert value in test_enum_values, 'Enum value deleted: {} {}\nbefore={}\nafter={}'.format(
                name, value, stable_enum_values, test_enum_values)


def check_multiple_prefix_styles(data):
    """Assert that the lines in ``data`` do not mix the ``xs:`` and ``xsd:`` prefixes.

    Raises ``AssertionError`` when both substrings occur somewhere in ``data``.
    """
    seen_xs = False
    seen_xsd = False
    for text in data:
        seen_xs = seen_xs or 'xs:' in text
        seen_xsd = seen_xsd or 'xsd:' in text
    assert not (seen_xs and seen_xsd), 'File contains xs: and xsd: at same time'


def validate_diff(xsd_filename, stable_file, test_file):
    """Run every validation on one XSD file given its stable and test versions.

    ``stable_file`` and ``test_file`` are sequences of lines.  The checks run
    in a fixed order: prefix style of both versions, complexType/group
    definitions of the test version, changes in complexTypes/groups, and
    finally changes in simpleTypes.  The first failing check raises
    ``AssertionError``.
    """
    for lines in (stable_file, test_file):
        check_multiple_prefix_styles(lines)

    old_complex = find_all_complex_types_and_groups(stable_file)
    new_complex = find_all_complex_types_and_groups(test_file)
    groups_of_stable_events = get_groups_from_stable_events(old_complex)

    validate_complex_types_and_groups(xsd_filename,
                                      stable_file, old_complex,
                                      test_file, new_complex)
    validate_changes_in_complex_types_and_groups(xsd_filename,
                                                 stable_file, old_complex, groups_of_stable_events,
                                                 test_file, new_complex)

    old_simple = find_all_simple_types(stable_file)
    new_simple = find_all_simple_types(test_file)
    validate_changes_in_simple_types(xsd_filename, stable_file, old_simple, test_file, new_simple)


class ScarabXsdValidate(sdk2.Task):
    """Task that validates changes made by a patch to the XSD files in ``scarab/xsd``.

    The XSD files are read once from the checked-out tree, then again after
    the optional patch has been applied, and each file is checked with
    ``validate_diff``.
    """

    class Parameters(sdk2.Parameters):

        arcadia_url = sdk2.parameters.ArcadiaUrl(required=True)
        arcadia_patch = sdk2.parameters.String('Patch for arcadia')

    def GetRevisonStrSafe(self, svn_url):
        """Return the revision of ``svn_url`` as an int.

        Raises ``TaskFailure`` when ``GetRevision`` reports ``"HEAD"``, i.e.
        when the url carries no explicit revision.
        """
        revision = GetRevision(svn_url)

        if revision == "HEAD":
            raise common.errors.TaskFailure("Bad svn url, revision needed: {}".format(svn_url))

        return int(revision)

    def collect_xsd_files(self, xsd_dir):
        """Read every ``*.xsd`` file directly inside ``xsd_dir``.

        Returns a dict mapping the file name to the list of lines produced
        by ``read_file``.
        """
        res = dict()
        for xsd_filename in os.listdir(xsd_dir):
            logging.info('ScarabXsdValidate: collect xsd {}'.format(xsd_filename))
            if not xsd_filename.endswith('.xsd'):
                continue

            xsd_filepath = os.path.join(xsd_dir, xsd_filename)
            res[xsd_filename] = read_file(xsd_filepath)
        return res

    def on_execute(self):
        """Check out ``scarab/xsd``, apply the patch if one is given and validate every XSD file."""
        logging.info('ScarabXsdValidate: Start task')
        arcadia_url = self.Parameters.arcadia_url
        arcadia_dir = sdk.do_clone(arcadia_url, self)
        ya_path = os.path.join(arcadia_dir, 'ya')
        self.execute('checkout', ya_path + ' make --checkout scarab/xsd -j0', cwd=arcadia_dir)
        xsd_dir = os.path.join(arcadia_dir, 'scarab', 'xsd')

        # Snapshot of the schemas before the patch is applied.
        stable_xsd_files = self.collect_xsd_files(xsd_dir)

        arcadia_patch = self.Parameters.arcadia_patch
        logging.info('ScarabXsdValidate: arcadia_dir: {}, xsd_dir: {}, arcadia_patch: {} | bool: {}, path: {}, log_path: {}'.format(
                     arcadia_dir, xsd_dir, arcadia_patch, bool(arcadia_patch), self.path(), self.log_path()))
        if arcadia_patch:
            sdk2.vcs.svn.Arcadia.apply_patch(arcadia_dir, arcadia_patch, str(self.path()))
        logging.info('ScarabXsdValidate: After, arcadia_dir: {}, xsd_dir: {}, arcadia_patch: {} | bool: {}, path: {}, log_path: {}'.format(
                     arcadia_dir, xsd_dir, arcadia_patch, bool(arcadia_patch), self.path(), self.log_path()))

        logging.info('ScarabXsdValidate: xsd_dir: {}'.format(xsd_dir))

        # Snapshot of the schemas after the patch.
        test_xsd_files = self.collect_xsd_files(xsd_dir)

        for xsd_filename in stable_xsd_files:
            assert xsd_filename in test_xsd_files, 'Xsd file deleted: {}'.format(xsd_filename)

        for xsd_filename, test_xsd in test_xsd_files.iteritems():
            logging.info('ScarabXsdValidate: xsd_file {}'.format(xsd_filename))
            # A file added by the patch is compared against an empty list of
            # lines, the same type collect_xsd_files() produces.
            stable_xsd = stable_xsd_files.get(xsd_filename, [])

            try:
                validate_diff(xsd_filename, stable_xsd, test_xsd)
            except AssertionError:
                # A failed assertion is a rule violation found in the schema.
                eh.check_failed('Xsd {} validate failed:\n{}'.format(xsd_filename, eh.shifted_traceback()))
            except Exception:
                # Anything else is an unexpected error in the validator itself.
                eh.fail('Xsd {} validate failed:\n{}. Please write to logs-team@'.format(xsd_filename, eh.shifted_traceback()))

        logging.info('ScarabXsdValidate: Task finished')

    def execute(self, logger, command, cwd='.', env=None):
        """Run ``command`` (split on whitespace) in ``cwd`` and wait for it to finish.

        The child environment is a copy of ``os.environ`` overlaid with the
        entries of ``env`` and with ``TZ`` forced to ``Europe/Moscow``.  The
        ``env`` dict passed by the caller is not modified.  Output goes to
        the ``ProcessLog`` named by ``logger``.

        Raises ``TaskFailure`` when the command exits with a non-zero code.
        """
        process_env = dict(os.environ)
        if env:
            # Caller-supplied variables take precedence over inherited ones.
            process_env.update(env)
        process_env['TZ'] = 'Europe/Moscow'
        with ProcessLog(self, logger=logger) as pl:
            retcode = subprocess.Popen(command.split(), stdout=pl.stdout, stderr=subprocess.STDOUT, cwd=cwd, env=process_env).wait()
            if retcode == 0:
                return
            raise common.errors.TaskFailure('%s failed' % logger)
