"""
These are all of the tests for this repository.  I fully expect these to require
organic refactor as we add more requirements.
"""

import os
import re
import sys
import unittest

import build
import publish

# Pre-load all data to speed up testing performance.
# Everything here is loaded once, at import time, through the project's `build`
# module and is then shared by every test case in this file.
_types = build.list_types()
# _fields, _groups and _events are keyed by file name; the tests below use
# os.path.basename() on the keys and treat the values as the parsed objects.
_fields = build.get_fields()
_groups = build.get_groups()
_events = build.get_events()
_transforms = build.list_transforms()
_expectations = build.list_expectations()
# Looked up per object kind ('events', 'fields', 'types', ...) in test_not_deprecated.
_deprecated = build.get_deprecated()
# Registered test exceptions: exception_exists() compares each entry with a dict
# made of 'suite', 'method' and the subTest parameters to decide on a skip.
_exceptions = build.load_object(build.project_path() + '/scripts/exceptions.yaml')['exceptions']
# Generally these are the reserved words in Redshift.
# Field and event names are rejected when they exactly match an entry of this
# set; every entry is lowercase.
_reserved = {
    'aes128', 'aes256', 'all', 'allowoverwrite', 'analyse', 'analyze', 'and', 'any', 'array', 'asc',
    'authorization', 'backup', 'between', 'binary', 'blanksasnull', 'both', 'bytedict', 'bzip2',
    'case', 'cast', 'check', 'collate', 'column', 'constraint', 'create', 'credentials', 'cross',
    'default', 'deferrable', 'deflate', 'defrag', 'delta', 'delta32k', 'desc', 'disable',
    'distinct', 'else', 'emptyasnull', 'enable', 'encode', 'encrypt', 'encryption', 'end', 'except',
    'explicit', 'false', 'for', 'foreign', 'freeze', 'from', 'full', 'globaldict256',
    'globaldict64k', 'grant', 'group', 'gzip', 'having', 'identity', 'ignore', 'ilike', 'initially',
    'inner', 'intersect', 'into', 'isnull', 'join', 'leading', 'left', 'like', 'limit', 'localtime',
    'localtimestamp', 'lun', 'luns', 'lzo', 'lzop', 'minus', 'mostly13', 'mostly32', 'mostly8',
    'natural', 'new', 'not', 'notnull', 'null', 'nulls', 'off', 'offline', 'offset', 'oid', 'old',
    'only', 'open', 'order', 'outer', 'overlaps', 'parallel', 'partition', 'percent', 'permissions',
    'placing', 'primary', 'raw', 'readratio', 'recover', 'references', 'respect', 'rejectlog',
    'resort', 'restore', 'right', 'select', 'similar', 'snapshot', 'some', 'sysdate', 'system',
    'table', 'tag', 'tdes', 'text255', 'text32k', 'then', 'timestamp', 'top', 'trailing', 'true',
    'truncatecolumns', 'union', 'unique', 'user', 'using', 'verbose', 'wallet', 'when', 'where',
    'with', 'without'
}


# This context manager allows skipping an anonymous block
# https://code.google.com/archive/p/ouspg/wikis/AnonymousBlocksInPython.wiki
class SkipBlockContextManager(object):
    """Context manager whose ``with`` body is skipped instead of executed.

    ``__enter__`` installs a trace hook on the calling frame; the hook raises on
    the first traced event of the body and ``__exit__`` swallows that exception,
    so execution resumes right after the ``with`` statement.

    Whatever tracing was active before the block (for instance a coverage tool
    or a debugger) is saved on entry and put back on exit, so skipping a block
    does not silently disable it for the rest of the run.
    """

    class _SkipBlock(Exception):
        """Raised by the trace hook to abort the body of the ``with`` block."""

    def __enter__(self):
        # Remember the tracing state we are about to overwrite.
        self._saved_tracer = sys.gettrace()
        self._frame = sys._getframe(1)
        self._saved_frame_tracer = self._frame.f_trace
        # Per-frame trace functions are only honoured while a global trace
        # function is installed, hence the do-nothing global hook.
        sys.settrace(lambda *args, **keys: None)
        # This causes an exception to be raised on the first instruction of the block
        self._frame.f_trace = self.trace

    def trace(self, frame, event, arg):
        """Trace hook: abort the traced block with an explicit exception."""
        raise self._SkipBlock()

    def __exit__(self, type, value, traceback):
        # An exception escaping a trace function makes the interpreter drop the
        # trace hooks, so restore both the frame-local and the global one here.
        frame, self._frame = self._frame, None
        frame.f_trace = self._saved_frame_tracer
        sys.settrace(self._saved_tracer)
        # Swallow whatever ended the block: the block counts as skipped.
        return True


def exception_exists(suite, method, params):
    """Tell whether a registered exception asks for this test to be skipped.

    An entry of ``_exceptions`` matches when it is equal to the dictionary made
    of ``suite``, ``method`` and every key/value pair of ``params``.
    """
    return any(
        entry == dict({'suite': suite, 'method': method}, **params)
        for entry in _exceptions
    )


class BaseTestCase(unittest.TestCase):
    """Base test case holding the validation helpers shared by the suites.

    The ``_check_*`` helpers raise the usual unittest failure when the object
    they are given is malformed; a missing required key is reported as a
    failure with a message rather than surfacing as a ``KeyError``.
    """

    def subTest(self, msg=None, **params):
        """Extend subTest with the skipTest behavior we want"""
        # Unittest does not have a good way of skipping subTests, so this forces that behavior.
        # https://bugs.python.org/issue25894   https://bugs.python.org/issue35327
        if exception_exists(type(self).__name__, self._testMethodName, params):
            return SkipBlockContextManager()
        return super().subTest(msg=msg, **params)

    def _check_keys(self, obj, supported, obj_type):
        """Fail if ``obj`` has a key that is not listed in ``supported``."""
        for key in obj:
            self.assertIn(key, supported, f"'{key}' is not a supported field in {obj_type}")

    def _require_keys(self, obj, required, obj_type):
        """Fail if ``obj`` lacks one of the ``required`` keys."""
        for key in required:
            self.assertIn(key, obj, f"'{key}' is required in {obj_type}")

    def _check_type(self, t):
        """Validate a type definition: known name plus its per-type options."""
        self.assertIn('name', t, 'types need names')
        self.assertIn(t['name'], _types, 'type not found')
        if t['name'] in ('long', 'bool', 'float'):
            self._check_keys(t, {'name'}, f'type({t["name"]})')
        elif t['name'] == 'int':
            self.fail('int type is deprecated')
        elif t['name'] == 'string':
            self._check_keys(t, {'name', 'length'}, f'type({t["name"]})')
            self.assertIn('length', t, 'strings require a length')
            self.assertIsInstance(t['length'], int, 'length must be an int')
            self.assertLessEqual(t['length'], 65535, 'length must be at most 65535')
        elif t['name'] == 'enum':
            self._check_keys(t, {'name', 'values'}, f'type({t["name"]})')
            self.assertIn('values', t, 'enums need values')
            self.assertIsInstance(t['values'], list, 'values must be a list')
            self.assertGreater(len(t['values']), 0, 'enum type needs at least 1 value')
            self.assertGreaterEqual(100, len(t['values']), 'enum supports 100 values')
            for val in t['values']:
                self.assertIsInstance(val, str, 'enum values must be strings')
        elif t['name'] == 'timestamp':
            self._check_keys(t, {'name', 'timezone'}, f'type({t["name"]})')
            self.assertIn('timezone', t, 'timestamps require a timezone')
            timezones = ['UTC', 'America/Los_Angeles']
            self.assertIn(t['timezone'], timezones, f'timezone must be in {timezones}')
        else:
            self.fail(f'unknown type {t["name"]}')

    def _check_transform(self, t):
        """Validate a transform definition: known name and a string source."""
        self.assertIn('name', t, 'transforms need names')
        self.assertIn(t['name'], _transforms, 'transform not found')
        self.assertIn('source', t, 'transforms need a source')
        self.assertIsInstance(t['source'], str, 'source must be string')

    def _check_expectation(self, e, field):
        """Validate expectation ``e`` and that it applies to the type of ``field``."""
        self.assertIn('name', e, 'expectations need names')
        self.assertIn(e['name'], _expectations, 'expectation not found')
        label = f'expectation({e["name"]})'
        if e['name'] == 'value_lengths_to_be_between':
            supported = {'string'}
            self._check_keys(e, {'name', 'min', 'max'}, label)
            self._require_keys(e, ('min', 'max'), label)
            self.assertIsInstance(e['min'], int, 'min must be an int')
            self.assertGreaterEqual(e['min'], 0, 'min must be >= 0')
            self.assertIsInstance(e['max'], int, 'max must be an int')
            self.assertGreater(e['max'], 0, 'max must be >= 1')
            self.assertGreaterEqual(e['max'], e['min'], 'max must be >= min')
        elif e['name'] == 'values_to_be_between':
            supported = {'long', 'float', 'string', 'timestamp'}
            self._check_keys(e, {'name', 'min', 'max'}, label)
            self._require_keys(e, ('min', 'max'), label)
            self.assertIsInstance(e['min'], int, 'min must be an int')
            self.assertIsInstance(e['max'], int, 'max must be an int')
            self.assertGreaterEqual(e['max'], e['min'], 'max must be >= min')
            self.assertIsInstance(e['max'], type(e['min']), 'min and max must be same type')
        elif e['name'] == 'values_to_be_in_set':
            supported = {'long', 'bool', 'float', 'string', 'timestamp'}
            self._check_keys(e, {'name', 'values'}, label)
            self._require_keys(e, ('values',), label)
            self.assertIsInstance(e['values'], list, 'values must be a list')
            self.assertGreater(len(e['values']), 0, 'expectation needs at least 1 value')
            self.assertGreaterEqual(100, len(e['values']), 'expectation supports 100 values')
            # Compare every value with its successor
            for current, following in zip(e['values'], e['values'][1:]):
                self.assertIsInstance(current, type(following),
                                      'values need to be of same type')
        elif e['name'] == 'values_to_not_be_null':
            supported = set(_types)  # all types are supported
            self._check_keys(e, {'name'}, label)
        elif e['name'] == 'values_to_match_regex_list':
            supported = {'string'}
            self._check_keys(e, {'name', 'patterns'}, label)
            self._require_keys(e, ('patterns',), label)
            self.assertIsInstance(e['patterns'], list, 'patterns must be a list')
            self.assertGreater(len(e['patterns']), 0, 'expectation needs at least 1 pattern')
            self.assertGreaterEqual(3, len(e['patterns']), 'expectation supports 3 patterns')
            for val in e['patterns']:
                self.assertIsInstance(val, str, 'patterns must be strings')
                # Report an invalid pattern as a failure naming the pattern
                try:
                    re.compile(val)
                except re.error as err:
                    self.fail(f'pattern {val!r} is not a valid regex: {err}')
        else:
            self.fail('unknown expectation type')
        self.assertNotIn('transform', field, 'expectations for transforms are not supported')
        self.assertIn('type', field, 'expectations require a field with a type')
        self.assertIn(field['type']['name'], supported, f'expectation only for {supported}')

    def _check_override(self, name, o, field):
        """Validate override ``o`` of event ``name`` against the overridden ``field``."""
        self._check_keys(o, {'name', 'type', 'transform', 'expectations', 'source'}, 'override')
        if 'type' in o or 'transform' in o:
            k = 'type' if 'type' in o else 'transform'
            self.assertNotEqual('type' in o, 'transform' in o, 'can have type xor transform')
            self.assertIn('source', o[k], 'can only override source for type')
            self.assertIsInstance(o[k]['source'], str, 'source must be string')
            self.assertNotEqual(o[k]['source'], o['name'], 'source cannot be same as name')
        if 'expectations' in o:
            self.assertIsInstance(o['expectations'], list, 'expectations must be a list')
            for e in o['expectations']:
                # The name is needed to label the subTest, so check it first
                self.assertIn('name', e, 'expectations need names')
                with self.subTest(event=name, override=o['name'], expectation=e['name']):
                    self._check_expectation(e, field)


class TestCore(BaseTestCase):
    """Checks that apply equally to fields, groups and events."""

    def test_no_filename_conflicts(self):
        """No two files of the same object kind may share a base name."""
        catalog = (
            ('fields', sorted(_fields.keys())),
            ('groups', sorted(_groups.keys())),
            ('events', sorted(_events.keys())),
        )
        for obj_name, names in catalog:
            # Compare each name with every name that sorts after it
            for position in range(len(names) - 1):
                n1 = names[position]
                for n2 in names[position + 1:]:
                    with self.subTest(object=obj_name, filename=n1, other_filename=n2):
                        first = os.path.basename(n1)
                        second = os.path.basename(n2)
                        self.assertNotEqual(first, second, 'filename objects must be unique')

    def test_is_object_and_name_matches_filename(self):
        """Each file holds one dictionary whose single key is the file's base name."""
        catalog = (
            ('fields', _fields),
            ('groups', _groups),
            ('events', _events),
        )
        for obj_name, loaded in catalog:
            for filename in sorted(loaded):
                obj = loaded[filename]
                with self.subTest(object=obj_name, filename=filename):
                    self.assertIsInstance(obj, dict, 'yaml objects must be dictionaries')
                    self.assertEqual(len(obj), 1, 'each file must only have one element')
                    name = os.path.basename(filename).partition('.yaml')[0]
                    self.assertIn(name, obj, 'each object name must match its filename')

    def test_no_name_mismatches(self):
        """Indexing the objects by name must not lose any of them."""
        catalog = (
            ('fields', _fields),
            ('groups', _groups),
            ('events', _events),
        )
        for obj_name, loaded in catalog:
            with self.subTest(object=obj_name):
                by_name = build.objects_by_name(loaded.values())
                self.assertEqual(len(loaded), len(by_name), 'objects need unique names')

    def test_objects_have_description(self):
        """Every object carries a 'description' key."""
        catalog = (
            ('fields', build.objects_by_name(_fields.values())),
            ('groups', build.objects_by_name(_groups.values())),
            ('events', build.objects_by_name(_events.values())),
        )
        for obj_name, by_name in catalog:
            for name, contents in by_name.items():
                with self.subTest(object=obj_name, name=name):
                    self.assertIn('description', contents, 'all objects need a description')

    def test_object_description_minimum_length(self):
        """Every description is longer than 100 characters."""
        # The first item is the subTest keyword used for that kind of object
        catalog = (
            ('field', build.objects_by_name(_fields.values())),
            ('group', build.objects_by_name(_groups.values())),
            ('event', build.objects_by_name(_events.values())),
        )
        for label, by_name in catalog:
            for name, contents in by_name.items():
                with self.subTest(**{label: name}):
                    self.assertGreater(len(contents['description']), 100,
                                       'your description is not long enough')

    def test_not_deprecated(self):
        """Nothing listed as deprecated may still be defined."""
        # int is deprecated but we still have a type page
        # and allow users to add events with int fields
        # because migrating away from it is currently
        # more work than just temporarily supporting it.
        live_types = [t for t in _types if t != 'int']
        catalog = (
            ('events', build.objects_by_name(_events.values()).keys()),
            ('expectations', _expectations),
            ('fields', build.objects_by_name(_fields.values()).keys()),
            ('groups', build.objects_by_name(_groups.values()).keys()),
            ('transforms', _transforms),
            ('types', live_types),
        )
        for obj_name, defined in catalog:
            for dep in _deprecated.get(obj_name, []):
                with self.subTest(object=obj_name, name=dep):
                    self.assertNotIn(dep, defined, 'deprecated')


class TestFields(BaseTestCase):
    """Validation of the field definitions."""

    def setUp(self):
        self.fields = build.objects_by_name(_fields.values())

    def test_field_names_conform(self):
        """Field names follow the naming pattern and avoid reserved words."""
        name_re = re.compile(r'^[a-z][a-z0-9_]{1,126}$')
        for name in self.fields:
            with self.subTest(field=name):
                self.assertRegex(name, name_re, f"field name doesn't match {name_re}")
                self.assertNotIn(name, _reserved, 'field name is reserved name')
                self.assertNotIn('tahoe', name, 'field name contains tahoe')
                self.assertNotIn('spade', name, 'field name contains spade')

    def test_fields_have_only_known_keys(self):
        """Fields only use the supported keys."""
        supported = {'description', 'transform', 'type', 'internal', 'sensitivity', 'expectations'}
        for name, field in self.fields.items():
            with self.subTest(field=name):
                self._check_keys(field, supported, 'field')

    def test_fields_have_transform_xor_type(self):
        """A field declares exactly one of 'transform' and 'type'."""
        for name, field in self.fields.items():
            with self.subTest(field=name):
                self.assertNotEqual('transform' in field, 'type' in field,
                                    'fields must have a transform xor a type')

    def test_fields_sensitivity(self):
        """Every field declares one of the known sensitivities."""
        sensitivities = ['ip', 'userid', 'otherid', 'sessionid', 'none']
        for name, field in self.fields.items():
            with self.subTest(field=name):
                self.assertIn('sensitivity', field, 'fields must have sensitivities')
                self.assertIn(field['sensitivity'], sensitivities, 'sensitivity not valid')

    def test_fields_valid_internal(self):
        """When present, 'internal' must be the boolean True."""
        for name, field in self.fields.items():
            if 'internal' in field:
                with self.subTest(field=name):
                    # assertTrue would also accept any truthy value, e.g. the string "false"
                    self.assertIs(field['internal'], True, 'internal can only be true')

    def test_field_type_ref(self):
        """Type definitions of fields are valid."""
        for name, field in self.fields.items():
            if 'type' in field:
                with self.subTest(field=name):
                    self._check_type(field['type'])

    def test_field_transform_ref(self):
        """Transform definitions of fields are valid."""
        for name, field in self.fields.items():
            if 'transform' in field:
                with self.subTest(field=name):
                    self._check_transform(field['transform'])

    def test_field_expectations(self):
        """Expectations of fields are valid and apply to the field's type."""
        for name, field in self.fields.items():
            if 'expectations' in field:
                with self.subTest(field=name):
                    # Same requirement as for the expectations of an override
                    self.assertIsInstance(field['expectations'], list,
                                          'expectations must be a list')
                    for expectation in field['expectations']:
                        self.assertIn('name', expectation, 'expectations need names')
                        with self.subTest(field=name, expectation=expectation['name']):
                            self._check_expectation(expectation, field)


class TestGroups(BaseTestCase):
    """Validation of the group definitions."""

    def setUp(self):
        self.groups = build.objects_by_name(_groups.values())

    def test_groups_have_only_known_keys(self):
        """Groups only use the supported keys."""
        supported = {'description', 'fields', 'groups'}
        for name, group in self.groups.items():
            with self.subTest(group=name):
                self._check_keys(group, supported, 'group')

    def test_group_has_at_least_two_members(self):
        """A group references more than one field or group in total."""
        for name, group in self.groups.items():
            with self.subTest(group=name):
                count = len(group.get('fields', [])) + len(group.get('groups', []))
                self.assertGreater(count, 1, 'groups must have more than one fields and groups')

    def test_group_fields_sorted(self):
        """The 'fields' list of a group is sorted."""
        for name, group in self.groups.items():
            with self.subTest(group=name):
                self.assertEqual(group.get('fields', []), sorted(group.get('fields', [])),
                                 'field values must be sorted')

    def test_group_groups_sorted(self):
        """The 'groups' list of a group is sorted."""
        for name, group in self.groups.items():
            with self.subTest(group=name):
                self.assertEqual(group.get('groups', []), sorted(group.get('groups', [])),
                                 'group values must be sorted')

    def test_group_field_refs(self):
        """Every field referenced by a group exists."""
        field_names = set(build.objects_by_name(_fields.values()).keys())
        for name, group in self.groups.items():
            for field in group.get('fields', []):
                with self.subTest(group=name, field=field):
                    self.assertIn(field, field_names, 'field reference not found')

    def test_group_group_refs(self):
        """Every group referenced by a group exists."""
        group_names = set(build.objects_by_name(_groups.values()).keys())
        for name, group in self.groups.items():
            for group_ref in group.get('groups', []):
                with self.subTest(group=name, group_ref=group_ref):
                    self.assertIn(group_ref, group_names, 'group reference not found')

    def test_group_refs_acyclic(self):
        """No group reaches itself through its nested group references."""
        for name, group in self.groups.items():
            with self.subTest(group=name):
                # Walk a copy: popping from the group's own list would empty the
                # loaded definition, which may be shared with the other tests
                # (depends on build.objects_by_name, not visible from this file).
                pending = list(group.get('groups', []))
                visited = set()
                while pending:
                    g = pending.pop()
                    self.assertNotEqual(name, g, 'groups must be acyclic')
                    # Expand each group once so that a cycle which does not go
                    # through `name` cannot make this loop run forever.
                    if g in visited:
                        continue
                    visited.add(g)
                    # Unknown references are reported by test_group_group_refs
                    pending.extend(self.groups.get(g, {}).get('groups', []))


class TestEvents(BaseTestCase):
    """Validation of the event definitions."""

    def setUp(self):
        self.events = build.objects_by_name(_events.values())

    def test_event_names_conform(self):
        """Event names follow the naming pattern and avoid reserved words."""
        name_re = re.compile(r'^[a-z][a-z0-9_]{2,126}$')
        invalid_prefix = 'pg_'
        for name in self.events:
            with self.subTest(event=name):
                self.assertRegex(name, name_re, f"event name doesn't match {name_re}")
                self.assertFalse(name.startswith(invalid_prefix), 'event name starts with pg_')
                self.assertNotIn(name, _reserved, 'event name is reserved name')
                self.assertNotIn('tahoe', name, 'event name contains tahoe')
                self.assertNotIn('spade', name, 'event name contains spade')

    def test_events_have_only_known_keys(self):
        """Events only use the supported keys."""
        supported = {'description', 'fields', 'groups', 'overrides'}
        for name, event in self.events.items():
            with self.subTest(event=name):
                self._check_keys(event, supported, 'event')

    def test_events_have_fields_or_groups(self):
        """An event declares 'fields', 'groups' or both."""
        for name, event in self.events.items():
            with self.subTest(event=name):
                self.assertTrue('fields' in event or 'groups' in event,
                                'events must have fields or groups')

    def test_event_fields_sorted(self):
        """The 'fields' list of an event is sorted."""
        for name, event in self.events.items():
            with self.subTest(event=name):
                self.assertEqual(event.get('fields', []), sorted(event.get('fields', [])),
                                 'field values must be sorted')

    def test_event_groups_sorted(self):
        """The 'groups' list of an event is sorted."""
        for name, event in self.events.items():
            with self.subTest(event=name):
                self.assertEqual(event.get('groups', []), sorted(event.get('groups', [])),
                                 'group values must be sorted')

    def test_event_field_refs(self):
        """Every field referenced by an event exists."""
        field_names = set(build.objects_by_name(_fields.values()).keys())
        for name, event in self.events.items():
            for field in event.get('fields', []):
                with self.subTest(event=name, field=field):
                    self.assertIn(field, field_names, 'field reference not found')

    def test_event_group_refs(self):
        """Every group referenced by an event exists."""
        group_names = set(build.objects_by_name(_groups.values()).keys())
        for name, event in self.events.items():
            for group in event.get('groups', []):
                with self.subTest(event=name, group=group):
                    self.assertIn(group, group_names, 'group reference not found')

    def test_overrides(self):
        """Overrides name a field of the event and are themselves valid."""
        for name, event in self.events.items():
            if 'overrides' not in event:
                continue
            all_fields = self._resolve_all_fields(event)
            for override in event['overrides']:
                with self.subTest(event=name):
                    self.assertIn('name', override, 'override must have a name')
                    self.assertIn(override['name'], all_fields, 'name must be a valid name')
                    with self.subTest(event=name, override=override['name']):
                        self._check_override(name, override, all_fields[override['name']])

    def _resolve_all_fields(self, event):
        """Return ``{field name: field}`` for every field the event reaches.

        That is the event's own fields plus the fields of its groups, nested
        groups included. The event definition itself is left untouched, and
        names that do not resolve are skipped (dangling references are reported
        by the dedicated reference tests).
        """
        all_groups = build.objects_by_name(_groups.values())
        all_fields = build.objects_by_name(_fields.values())
        # Work on copies: extending/popping the event's own lists would alter
        # the loaded definition for whatever reads it afterwards.
        fields = list(event.get('fields', []))
        pending = list(event.get('groups', []))
        visited = set()
        while pending:
            group_name = pending.pop()
            # Expand each group once so cyclic group references terminate
            if group_name in visited:
                continue
            visited.add(group_name)
            g = all_groups.get(group_name, {})
            fields.extend(g.get('fields', []))
            pending.extend(g.get('groups', []))
        return {k: all_fields[k] for k in fields if k in all_fields}


class TestCodeowners(unittest.TestCase):
    """Validation of the .github/CODEOWNERS file."""

    def _check_object_expectations(self, owners, lines, prefix, objects):
        """Check one ownership section against the objects it has to cover.

        ``lines`` must be sorted, and the n-th line must mention ``prefix``
        followed by the n-th sorted key of ``objects`` and carry, as its second
        whitespace-separated token, one of the allowed ``owners``.
        """
        self.assertEqual(lines, sorted(lines), 'codeowners for objects must be sorted')
        for idx, fname in enumerate(sorted(objects.keys())):
            # A section shorter than the object list is a failure, not a crash
            self.assertLess(idx, len(lines), f'{prefix}{fname} is not covered by codeowners')
            line = lines[idx]
            self.assertIn(prefix+fname, line, 'object not covered by codeowners')
            tokens = line.split()
            self.assertGreater(len(tokens), 1, f'{prefix}{fname} has no code owner')
            self.assertIn(tokens[1], owners)

    def test_codeowners(self):
        """Every event, field and group is covered by an allowed owner."""
        with open(build.project_path('.github') + '/CODEOWNERS') as fp:
            lines = list(fp)

        # Allowed owners are declared on the lines starting with '## @spade'
        owners = {x.split()[1] for x in lines if x.startswith('## @spade')}

        # We want to make sure a specific team is never listed as a code owner
        self.assertNotIn('@spade/bp-editors-for-org-membership-not-for-event-ownership', owners)

        # We only want to allow the demo team in the playground repo.
        if not build.is_playground():
            owners -= {'@spade/demo'}

        with self.subTest(object_type='event'):
            to_check = build.parse_lines_between(lines, '# Event ownership.', '# End Event')
            self._check_object_expectations(owners, to_check, '/events/', _events)

        with self.subTest(object_type='field'):
            to_check = build.parse_lines_between(lines, '# Field ownership.', '# End Field')
            self._check_object_expectations(owners, to_check, '/fields/', _fields)

        with self.subTest(object_type='group'):
            to_check = build.parse_lines_between(lines, '# Group ownership.', '# End Group')
            self._check_object_expectations(owners, to_check, '/groups/', _groups)


class TestPublishing(unittest.TestCase):
    """Checks of the publishing and codegen output built from the definitions."""

    def test_endstate_generation(self):
        """convert_events_to_desired_state yields the expected columns.

        A hand-written event is converted using the loaded groups and fields,
        then the conversion is repeated with no events and the same event name
        passed as last argument, which must produce an entry without columns.
        """
        events = {'invariant_ping': {
                     'description': 'N/A',
                     'fields': ['origin', 'sequence_number', 'city', 'query_granularity'],
                     'groups': ['internal_ip'],
                     'overrides': [
                         {'name': 'sequence_number', 'expectations': []},
                         {'name': 'query_granularity', 'source': 'granularity'}
                        ]}
                 }
        groups = build.objects_by_name(_groups.values())
        fields = build.objects_by_name(_fields.values())
        desired_state = publish.convert_events_to_desired_state(events, groups, fields, [])
        self.assertEqual(len(desired_state['eventEndStates']), 1)
        desired_events = desired_state['eventEndStates']
        self.assertEqual(desired_events[0]['eventName'], 'invariant_ping')
        self.assertEqual(len(desired_events[0]['columns']), 7)
        desired_columns = desired_events[0]['columns']
        self.assertIn({
            'inboundName': 'origin',
            'outboundName': 'origin',
            'transformer': 'varchar',
            'columnCreationOptions': '(8)',
            'sensitivityType': None,
            }, desired_columns)
        self.assertIn({
            'inboundName': 'sequence_number',
            'outboundName': 'sequence_number',
            'transformer': 'bigint',
            'columnCreationOptions': '',
            'sensitivityType': None,
        }, desired_columns)
        self.assertIn({
            'inboundName': 'time',
            'outboundName': 'time',
            'transformer': 'f@timestamp@unix',
            'columnCreationOptions': '',
            'sensitivityType': None,
        }, desired_columns)
        self.assertIn({
            'inboundName': 'time',
            'outboundName': 'time_utc',
            'transformer': 'f@timestamp@unix-utc',
            'columnCreationOptions': '',
            'sensitivityType': None,
        }, desired_columns)
        self.assertIn({
            'inboundName': 'ip',
            'outboundName': 'ip',
            'transformer': 'varchar',
            'columnCreationOptions': '(15)',
            'sensitivityType': 'ip',
        }, desired_columns)
        self.assertIn({
            'inboundName': 'ip',
            'outboundName': 'city',
            'transformer': 'ipCity',
            'columnCreationOptions': '',
            # transforms inherit the sensitivity of source field
            'sensitivityType': None,
        }, desired_columns)
        self.assertIn({
            'inboundName': 'granularity',
            'outboundName': 'query_granularity',
            'transformer': 'varchar',
            'columnCreationOptions': '(64)',
            'sensitivityType': None,
        }, desired_columns)

        events = {}
        desired_state = publish.convert_events_to_desired_state(
                                events, groups, fields,
                                ['invariant_ping'])
        self.assertEqual(len(desired_state['eventEndStates']), 1)
        desired_events = desired_state['eventEndStates']
        self.assertEqual(desired_events[0]['eventName'], 'invariant_ping')
        self.assertEqual(len(desired_events[0]['columns']), 0)

    def test_make_codegen_schema(self):
        """make_codegen_schema yields the expected schema for a sample event."""
        events = {'invariant_ping': {
                     'description': 'N/A',
                     'fields': ['origin', 'sequence_number', 'city', 'login'],
                     'groups': ['internal_ip'],
                     'overrides': [{'name': 'sequence_number', 'expectations': []}]}
                 }
        groups = build.objects_by_name(_groups.values())
        fields = build.objects_by_name(_fields.values())
        # Expected on 'login' on top of its declared expectations: a length
        # bound taken from the length of its string type.
        added_exp = {'name': 'value_lengths_to_be_between', 'min': 0,
                     'max': fields['login']['type']['length']}
        expected = {'events': [
                       {'name': 'invariant_ping', 'description': 'N/A', 'fields': [
                           {'name': 'login',
                            'description': fields['login']['description'],
                            'type': {'name': 'string'},
                            'expectations': fields['login'].get('expectations', []) + [added_exp]},
                           {'name': 'origin',
                            'description': fields['origin']['description'],
                            'type': fields['origin']['type'],
                            'expectations': fields['origin'].get('expectations', [])},
                           {'name': 'sequence_number',
                            'description': fields['sequence_number']['description'],
                            'type': fields['sequence_number']['type'],
                            'expectations': []}
                       ]}
                   ]}
        # Build the schema once and compare that single result
        results = build.make_codegen_schema(events, groups, fields)
        self.assertEqual(expected, results)
