import datetime
import json
import logging
import time

# Metrika event names used by UserEventsDBYT.check_relevance, keyed first by
# app platform and then by notification step.
METRIKA_EVENTS = {
    'ios': {
        'signature': 'settings-tap-signature',
        'dark_theme': 'settings_dark_theme_turn_on',
    },
    'android': {
        'signature': 'account_settings_tap_signature_place',
        'dark_theme': 'dark_theme.was_triggered',
    },
}


class UserEventsDBYT:
    def __init__(self, client, yt_paths):
        """Remember the query client and the table paths used by the queries.

        Args:
            client: object whose ``query(query=...)`` call returns a request
                offering ``run()`` and ``get_results()``.
            yt_paths: object whose attributes (``METRIKA_1D``, ``USERS``,
                ``DEVICES``, ...) are substituted into the query texts.
        """
        self.logger = logging.getLogger("global")
        self.yt_paths = yt_paths
        self.client = client

    def get_fresh_installs(self):
        """Return yesterday's first-launch events, one entry per device.

        The query picks, from yesterday's Metrika table, the earliest
        init / first-launch / fresh-install event of every device that does
        not occur in the tables dated from 30 to 2 days ago.

        Returns:
            list of dicts keyed by the column names reported by the result
            table. Text values have every non-ASCII character removed,
            ``platform`` is lower-cased, and only the first row seen for a
            given ``device_id`` is kept.
        """
        # Take "today" once so all three dates agree even if the call
        # straddles midnight.
        today = datetime.date.today()
        date_format = "%Y-%m-%d"
        from_date = (today - datetime.timedelta(days=30)).strftime(date_format)
        to_date = (today - datetime.timedelta(days=2)).strftime(date_format)
        prev_date = (today - datetime.timedelta(days=1)).strftime(date_format)
        query = """
            SELECT
                EventID, UUID as uuid, DeviceID as device_id, DeviceIDHash, AccountID as uid, ReceiveDate,
                RegionID as region_id, CAST(RegionTimeZone as int64) as timezone, CAST(timestamp as Timestamp) as timestamp,
                ParsedParams_Key1, ParsedParams_Key2, OperatingSystem as os, OriginalManufacturer as manufacturer, OriginalModel as model, OSVersion as os_version,
                OSApiLevel as os_api_level, Locale as locale, DeviceType as device_type, AppID as app_id, AppPlatform as platform, AppVersionName as app_version,
                AppBuildNumber as app_build, ReceiveTimestamp
            FROM
                REGEXP([{metrika_1d}], '{prev_date}')
            WHERE
                EventID IN (
                    SELECT MIN(EventID) AS EventID
                    FROM REGEXP([{metrika_1d}], '{prev_date}')
                    WHERE
                        (APIKey = '29733' OR APIKey = '14836')
                        AND ((EventName = 'first-launch-after-update' AND ParsedParams_Key2 like '%fresh-install%')
                            OR EventType = 'EVENT_INIT'
                            OR EventType = 'EVENT_FIRST')
                        AND DeviceID NOT IN (
                            SELECT DeviceID
                            FROM RANGE([{metrika_1d}], '{check_from_date}', '{check_to_date}')
                            WHERE (APIKey = '29733' OR APIKey = '14836'))
                    GROUP BY DeviceID)
                AND CAST(RegionTimeZone as int64) IS NOT NULL
                AND Locale IS NOT NULL;
        """.format(
            metrika_1d=self.yt_paths.METRIKA_1D,
            check_from_date=from_date,
            check_to_date=to_date,
            prev_date=prev_date)
        table = self._fetch(query, 'get_fresh_installs')

        # ``unicode`` only exists on Python 2; fall back to ``str`` alone so
        # the type check does not raise NameError on Python 3.
        try:
            text_types = (str, unicode)
        except NameError:
            text_types = (str,)

        devices = []
        processed_devices = set()
        for row in table.rows:
            entry = {}
            for key, value in zip(table.column_names, row):
                if isinstance(value, text_types):
                    # Drop non-ASCII characters. ``join`` yields a string on
                    # both Python 2 and 3, whereas ``filter`` returns an
                    # iterator on Python 3.
                    value = "".join(c for c in value if ord(c) < 128)
                entry[key] = value
            # A missing or NULL platform is left untouched instead of
            # failing on ``None.lower()``.
            if entry.get('platform') is not None:
                entry['platform'] = entry['platform'].lower()
            device_id = entry['device_id']
            if device_id not in processed_devices:
                processed_devices.add(device_id)
                devices.append(entry)
        return devices

    def check_relevance(self, notifications):
        """Return the device ids reported by the Metrika and Xeno log checks.

        Notifications without a ``step`` key are ignored. Those with step
        ``'xeno'`` go to the Xeno log check; the others go to the Metrika
        log check when METRIKA_EVENTS has an entry for their platform and
        step, and are ignored otherwise.

        Args:
            notifications: iterable of dicts; ``platform`` is read for every
                notification whose step is present and not ``'xeno'``.

        Returns:
            union of the results of both log checks.
        """
        xeno_batch = []
        metrika_batch = []
        for item in notifications:
            if 'step' not in item:
                continue
            step = item['step']
            if step == 'xeno':
                xeno_batch.append(item)
                continue
            platform_events = METRIKA_EVENTS.get(item['platform'])
            if platform_events is not None and step in platform_events:
                metrika_batch.append(item)
        # The Metrika check runs first, then the Xeno one.
        by_metrika = self._check_relevance_by_metrika_log(metrika_batch)
        by_xeno = self._check_relevance_by_xeno_log(xeno_batch)
        return by_metrika.union(by_xeno)

    def get_uids_for_new_devices(self, devices):
        """Fill in missing ``uid`` values of ``devices`` in place.

        Uids are looked up by device id, first via the AccountID query and
        then via the ReportEnvironment query; when both report a uid for
        the same device, the ReportEnvironment one wins. Only devices whose
        ``uid`` is None are updated.

        Args:
            devices: list of dicts carrying ``device_id`` and ``uid`` keys.

        Returns:
            None; ``devices`` is modified in place.
        """
        if not devices:
            # Nothing to look up. Without this guard both helpers would be
            # asked to run a query containing an empty ``IN ()`` list.
            return
        device_ids = [device['device_id'] for device in devices]
        found = (self._get_uids_from_account_id(device_ids)
                 + self._get_uids_from_report_enviroment(device_ids))
        uid_by_device = {row['device_id']: row['uid'] for row in found}
        for device in devices:
            if device['uid'] is None and device['device_id'] in uid_by_device:
                device['uid'] = uid_by_device[device['device_id']]

    def get_new_uids_for_existing_devices(self):
        """Return (device, uid) pairs seen yesterday that are not stored yet.

        The query selects DeviceID/AccountID pairs from yesterday's Metrika
        table where the device is in the devices table or the account is in
        the users table, excluding pairs already in the user_device table.

        Returns:
            list of ``{'device_id': ..., 'uid': ...}`` dicts, each pair
            appearing once, in result order.
        """
        # FIXME Android devices may not have uid in AccountId
        yesterday = datetime.date.today() - datetime.timedelta(days=1)
        query = """
            SELECT DeviceID, AccountID
            FROM REGEXP([{metrika_1d}], '{prev_date}')
            WHERE
                (APIKey = '29733' OR APIKey = '14836')
                AND AccountID IS NOT NULL
                AND (DeviceID IN (SELECT device_id FROM [{devices}])
                    OR AccountID IN (SELECT uid FROM [{users}]))
                AND (DeviceID, AccountID) NOT IN (SELECT (device_id, uid) AS key FROM [{user_device}])
            GROUP BY DeviceID, AccountID
        """.format(
            metrika_1d=self.yt_paths.METRIKA_1D,
            prev_date=yesterday.strftime("%Y-%m-%d"),
            users=self.yt_paths.USERS,
            devices=self.yt_paths.DEVICES,
            user_device=self.yt_paths.USER_DEVICE)
        table = self._fetch(query, 'get_new_uids_for_existing_devices')
        pairs = []
        seen = set()
        for row in table.rows:
            record = dict(zip(table.column_names, row))
            pair = (record['DeviceID'], record['AccountID'])
            if pair in seen:
                continue
            seen.add(pair)
            pairs.append({'device_id': pair[0], 'uid': pair[1]})
        return pairs

    def write_new_devices(self, devices):
        """Insert ``devices`` into the user_device, devices and users tables.

        The same list is handed to each of the three writers, in that order.
        """
        writers = (self._write_user_device, self._write_devices, self._write_users)
        for write in writers:
            write(devices)

    def write_uids_for_existing_devices(self, user_device):
        """Insert the given (uid, device_id) entries into the user_device table.

        Args:
            user_device: list of dicts with ``uid`` and ``device_id`` keys,
                passed unchanged to ``_write_user_device``.
        """
        self._write_user_device(user_device)

    def _get_uids_from_report_enviroment(self, devices):
        """Look up uids stored in the ReportEnvironment of yesterday's events.

        For each requested device, the earliest event of yesterday whose
        ReportEnvironment keys mention ``uid`` is fetched. The keys and
        values columns are parsed as JSON arrays, and the value paired with
        the first key equal to ``'uid'`` is taken.

        Args:
            devices: list of device id strings. They are quoted into the
                query text as-is.
                NOTE(review): ids containing a quote character would break
                the query; confirm they are safe upstream.

        Returns:
            list of ``{'device_id': ..., 'uid': ...}`` dicts, one per
            distinct pair; empty when ``devices`` is empty.
        """
        if not devices:
            # An empty id list would render as ``DeviceID IN ()``; there is
            # nothing to look up, so skip the query altogether.
            return []
        # Android devices do not log uid into AccountID field. Instead they write uid at ReportEnvironment field.
        prev_date = (datetime.date.today()-datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        query = """
            SELECT EventID, DeviceID, ReportEnvironment_Keys, ReportEnvironment_Values
            FROM REGEXP([{metrika_1d}], '{prev_date}')
            WHERE EventID IN (
                SELECT min(EventID) AS EventID
                FROM REGEXP([{metrika_1d}], '{prev_date}')
                WHERE
                    (APIKey = '29733' OR APIKey = '14836')
                    AND ReportEnvironment_Keys like '%uid%'
                    AND DeviceID IN ({device_ids})
                GROUP BY DeviceID)
        """.format(
            metrika_1d=self.yt_paths.METRIKA_1D,
            prev_date=prev_date,
            device_ids=",".join("'{}'".format(device) for device in devices))
        table = self._fetch(query, 'get_uids_from_report_enviroment')
        ret = []
        processed_user_device = set()
        for row in table.rows:
            entry = dict(zip(table.column_names, row))
            report_keys = json.loads(entry['ReportEnvironment_Keys'])
            report_values = json.loads(entry['ReportEnvironment_Values'])
            # The LIKE filter also matches keys that merely contain "uid",
            # so an exact 'uid' key may be absent; '' marks "not found".
            uid = next(
                (value for key, value in zip(report_keys, report_values) if key == 'uid'),
                '')
            if uid == '':
                continue
            pair = (entry['DeviceID'], uid)
            if pair not in processed_user_device:
                processed_user_device.add(pair)
                ret.append({'device_id': pair[0], 'uid': uid})
        return ret

    def _get_uids_from_account_id(self, new_devices):
        """Look up uids logged in the AccountID field of yesterday's events.

        Selects DeviceID/AccountID pairs for the requested devices from
        yesterday's Metrika table, skipping pairs already present in the
        user_device table.

        Args:
            new_devices: list of device id strings. They are quoted into
                the query text as-is.
                NOTE(review): ids containing a quote character would break
                the query; confirm they are safe upstream.

        Returns:
            list of ``{'device_id': ..., 'uid': ...}`` dicts, one per
            distinct pair; empty when ``new_devices`` is empty.
        """
        if not new_devices:
            # An empty id list would render as ``DeviceID IN ()``; there is
            # nothing to look up, so skip the query altogether.
            return []
        prev_date = (datetime.date.today()-datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        query = """
            SELECT DeviceID, AccountID
            FROM REGEXP([{metrika_1d}], '{prev_date}')
            WHERE
                (APIKey = '29733' OR APIKey = '14836')
                AND AccountID IS NOT NULL
                AND DeviceID IN ({new_devices})
                AND (DeviceID, AccountID) NOT IN (SELECT (device_id, uid) AS key FROM [{user_device}])
            GROUP BY DeviceID, AccountID
        """.format(
            metrika_1d=self.yt_paths.METRIKA_1D,
            prev_date=prev_date,
            user_device=self.yt_paths.USER_DEVICE,
            new_devices=",".join("'{}'".format(device) for device in new_devices))
        table = self._fetch(query, 'get_uids_from_account_id')
        ret = []
        processed_user_device = set()
        for row in table.rows:
            entry = dict(zip(table.column_names, row))
            pair = (entry['DeviceID'], entry['AccountID'])
            if pair not in processed_user_device:
                processed_user_device.add(pair)
                ret.append({'device_id': pair[0], 'uid': pair[1]})
        return ret

    def _write_devices(self, devices):
        """Insert one row per device into the devices table.

        Text values are written as single-quoted literals, None as
        ``null`` and everything else through ``str()``.

        Args:
            devices: list of dicts holding every column listed in
                ``fields`` below; nothing is written when it is empty.

        NOTE(review): quote characters inside text values are not escaped,
        so such a value would break the statement; confirm values are
        sanitized upstream.
        """
        if not devices:
            return
        fields = [
            'uuid', 'device_id', 'region_id', 'timezone',
            'os', 'os_version', 'os_api_level', 'manufacturer', 'model',
            'locale', 'device_type', 'app_id', 'platform', 'app_version', 'app_build']

        # ``unicode`` only exists on Python 2; fall back to ``str`` alone so
        # the type check does not raise NameError on Python 3.
        try:
            text_types = (str, unicode)
        except NameError:
            text_types = (str,)

        def to_literal(value):
            # Render a single Python value as a literal for the VALUES list.
            if isinstance(value, text_types):
                return "'{}'".format(value)
            if value is None:
                return 'null'
            return str(value)

        entries = [
            "({})".format(",".join(to_literal(device[field]) for field in fields))
            for device in devices]
        query = """
            INSERT INTO [{devices}] ({fields})
            VALUES {entries};
        """.format(devices=self.yt_paths.DEVICES, fields=", ".join(fields), entries=",".join(entries))
        self._execute(query, 'write_devices')

    def _write_users(self, users):
        """Insert one (uid, current UTC timestamp) row per entry into the users table.

        Args:
            users: list of dicts with a ``uid`` key; nothing is written when
                the list is empty.
        """
        if not len(users):
            return
        value_rows = [
            "('{}', CurrentUTCTimestamp())".format(item['uid'])
            for item in users]
        query = """
            INSERT INTO [{users}] (uid, mining_ts)
            VALUES {entries};
        """.format(users=self.yt_paths.USERS, entries=",".join(value_rows))
        self._execute(query, 'write_users')

    def _write_user_device(self, authorized_devices):
        """Insert one (uid, device_id) row per entry into the user_device table.

        Args:
            authorized_devices: list of dicts with ``uid`` and ``device_id``
                keys; nothing is written when the list is empty.
        """
        if not len(authorized_devices):
            return
        value_rows = [
            "('{}', '{}')".format(item['uid'], item['device_id'])
            for item in authorized_devices]
        query = """
            INSERT INTO [{user_device}] (uid, device_id)
            VALUES {rows};
        """.format(user_device=self.yt_paths.USER_DEVICE, rows=",".join(value_rows))
        self._execute(query, 'write_user_device')

    def _check_relevance_by_metrika_log(self, notifications):
        """Return device ids whose expected Metrika event has been logged.

        Every notification is turned into a (device_id, event name) pair
        via METRIKA_EVENTS. The pairs are searched in the daily tables
        dated from 4 days ago to yesterday and in the 5-minute tables.

        Args:
            notifications: list of dicts with ``device_id``, ``platform``
                and ``step`` keys.

        Returns:
            set of DeviceID values returned by the query; an empty set when
            there are no notifications.
        """
        if not len(notifications):
            return set()
        pairs = []
        for item in notifications:
            device_id = item['device_id']
            event_name = METRIKA_EVENTS[item['platform']][item['step']]
            pairs.append("('{}','{}')".format(device_id, event_name))
        entries = ",".join(pairs)
        prequery = """
            SELECT DeviceID
            FROM {tables}
            WHERE (APIKey = '29733' OR APIKey = '14836') AND (DeviceID, EventName) IN ({entries})
        """
        today = datetime.date.today()
        first_day = (today - datetime.timedelta(days=4)).strftime("%Y-%m-%d")
        last_day = (today - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        sources = (
            "RANGE([{}], '{}', '{}')".format(self.yt_paths.METRIKA_1D, first_day, last_day),
            'RANGE([{}])'.format(self.yt_paths.METRIKA_5M),
        )
        # The same select runs over the daily and the 5-minute tables.
        query = " UNION ALL ".join(
            [prequery.format(tables=source, entries=entries) for source in sources])
        table = self._fetch(query, 'check_relevance_by_metrika_log')
        return set(row[0] for row in table.rows)

    def _check_relevance_by_xeno_log(self, notifications):
        """Return device ids whose user shows up in the Xeno logs.

        The user_device rows of the notified devices are joined by uid with
        the Xeno daily tables dated from 4 days ago to yesterday and with
        the 5-minute tables.

        Args:
            notifications: list of dicts with a ``device_id`` key.

        Returns:
            set built from the first column of every result row; an empty
            set when there are no notifications.
        """
        if not len(notifications):
            return set()
        entries = ",".join("'{}'".format(item['device_id']) for item in notifications)
        prequery = """
            SELECT device_id
            FROM
                {tables} AS xeno
            JOIN
                (SELECT * FROM [{user_device}] WHERE device_id IN ({entries})) AS user_device
            ON xeno.uid = user_device.uid
        """
        today = datetime.date.today()
        first_day = (today - datetime.timedelta(days=4)).strftime("%Y-%m-%d")
        last_day = (today - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        sources = (
            "RANGE([{}], '{}', '{}')".format(self.yt_paths.XENO_1D, first_day, last_day),
            'RANGE([{}])'.format(self.yt_paths.XENO_5M),
        )
        # The same join runs over the daily and the 5-minute tables.
        query = " UNION ALL ".join([
            prequery.format(tables=source, user_device=self.yt_paths.USER_DEVICE, entries=entries)
            for source in sources])
        table = self._fetch(query, 'check_relevance_by_xeno_log')
        return set(row[0] for row in table.rows)

    def _execute(self, query, name):
        """Run ``query`` for its side effects and log how long it took.

        Args:
            query: query text passed to ``self.client.query``.
            name: label used in the timing log line.
        """
        started_at = time.time()
        request = self.client.query(query=query)
        request.run()
        # The results are requested but not used by this method.
        request.get_results()
        elapsed = time.time() - started_at
        message = 'run query name={name} time={time}'.format(name=name, time=elapsed)
        self.logger.info(message)

    def _fetch(self, query, name):
        """Run ``query`` and return its first result table, fully fetched.

        Args:
            query: query text passed to ``self.client.query``.
            name: label used in the timing log line.

        Returns:
            the first table yielded by the request results, after
            ``fetch_full_data()`` has been called on it.

        Raises:
            IndexError: if the results contain no table.
        """
        started_at = time.time()
        request = self.client.query(query=query)
        request.run()
        tables = list(request.get_results())
        table = tables[0]
        table.fetch_full_data()
        # The logged time covers both running the query and fetching the data.
        elapsed = time.time() - started_at
        message = 'run query name={name} time={time}'.format(name=name, time=elapsed)
        self.logger.info(message)
        return table


class FakeUserEventsDB:
    """In-memory stand-in offering the same public methods as UserEventsDBYT.

    Read methods return whatever was preloaded into the matching attribute;
    write methods accumulate their arguments in the ``written_*`` lists.
    """

    def __init__(self):
        # Canned answers handed back by the read methods.
        self.fresh_installs = []
        self.irrelevant_devices = []
        self.new_uids_for_existing_devices = []
        self.uids_for_new_devices = {}  # device_id -> uid
        # Everything passed to the write methods, in call order.
        self.written_devices = []
        self.written_uids_for_existing_devices = []

    def check_relevance(self, notifications):
        """Return the preloaded ``irrelevant_devices``; the argument is ignored."""
        return self.irrelevant_devices

    def get_new_uids_for_existing_devices(self):
        """Return the preloaded ``new_uids_for_existing_devices``."""
        return self.new_uids_for_existing_devices

    def get_fresh_installs(self):
        """Return the preloaded ``fresh_installs``."""
        return self.fresh_installs

    def get_uids_for_new_devices(self, devices):
        """Set ``uid`` in place on every device listed in ``uids_for_new_devices``."""
        known_uids = self.uids_for_new_devices
        for device in devices:
            device_id = device['device_id']
            if device_id in known_uids:
                device['uid'] = known_uids[device_id]

    def write_uids_for_existing_devices(self, user_device):
        """Append the given entries to ``written_uids_for_existing_devices``."""
        self.written_uids_for_existing_devices += user_device

    def write_new_devices(self, devices):
        """Append the given devices to ``written_devices``."""
        self.written_devices += devices
