# -*- coding: utf-8 -*-
import argparse
import itertools
import json
import os
import pickle
import re
import sys

from nile.api.v1 import (
    Record,
)
from nile.api.v1 import (
    aggregators as na
)
from nile.api.v1 import clusters
from nile.api.v1 import filters as nf
from qb2.api.v1 import filters as qf


class MetrikaLogProcessor:
    """
        Filters and transforms events from the AppMetrica log.
        Events that look like garbage can be dropped here.

        Be careful with ``ignored_events`` in the constructor: an event is
        ignored when any entry is a *substring* of its name, not only on
        strict equality (e.g. 'ads' also matches 'downloads').
    """

    def __init__(self, supported_api_keys, supported_app_versions, ignored_events=None):
        """
        :param supported_api_keys: APIKey values to keep
        :param supported_app_versions: AppVersionName values to keep; a falsy
            value (None / empty list) disables the version filter
        :param ignored_events: substrings marking an event as ignored;
            None selects the built-in list below
        """
        if ignored_events is None:
            ignored_events = [
                'DEBUG EVENT',
                'mail_dns_resolver',
                'serialize',
                'Perf',
                'push',
                'Push',
                'ads',
                'am_uid_ok',
                'google_play_services_version',
                'subscribe',
                'network',
                'dynametric_web_view_event/MailActivity/react_web_view/DOMReady',
                'message_load_time_offline',
                'task_finished_ok',
                'Exception',
                'message_list_load_time_offline',
                'net_stat_yandex',
                'Cache increased after clear cache',
                'com.e.a.a',
                'java.lang',
            ]
        self.__supported_api_keys = supported_api_keys
        self.__supported_app_versions = supported_app_versions
        self.__ignored_events = ignored_events

    def __call__(self, records):
        """
            Mapper: yield one normalized Record(uuid, puid, event_name, version,
            timestamp, event_number) per accepted input row.

            A row is dropped when its API key or app version is not supported,
            UUID is missing, the session is not SESSION_FOREGROUND, the device
            is not a PHONE, the event is EVENT_ALIVE or EVENT_ERROR, or the
            event name matches an ignored substring.
        """
        users_pattern = re.compile('user[0-9]+')
        folder_pattern = re.compile('([0-9]+)folder')
        fragment_pattern = re.compile('(Fragment/([0-9]+))')

        for rec in records:
            if rec.APIKey not in self.__supported_api_keys:
                continue

            if self.__supported_app_versions and rec.AppVersionName not in self.__supported_app_versions:
                continue

            if rec.UUID is None:
                continue

            if rec.SessionType != 'SESSION_FOREGROUND':
                continue

            if rec.DeviceType != 'PHONE':
                continue

            if rec.EventType == 'EVENT_ALIVE':
                continue

            event_name = self.__create_composite_event_name(rec.EventName, rec.EventValue, rec.EventType)

            if self.__is_ignored(event_name):
                continue

            # Collapse numeric ids embedded in names so that e.g. 'user1' and
            # 'user2' end up as the same event.
            event_name = users_pattern.sub('user', event_name)
            event_name = folder_pattern.sub('folder', event_name)
            event_name = fragment_pattern.sub('Fragment/NUMBER', event_name)

            yield Record(uuid=rec.UUID, puid=self.__extract_puid(rec), event_name=event_name,
                         version=rec.AppVersionName, timestamp=int(rec.EventTimestamp),
                         event_number=int(rec.EventNumber) if rec.EventNumber else 0)

    def __extract_puid(self, rec):
        """
            Return the 'uid' value from the row's report environment
            (parallel JSON lists ReportEnvironment_Keys / ReportEnvironment_Values),
            or '#undefined' when it is absent or cannot be parsed.
        """
        raw_keys = rec.ReportEnvironment_Keys
        # Cheap substring pre-check that skips JSON parsing for most rows. It is
        # only a hint: a 'uuid' key contains 'uid' too, so membership is
        # re-checked on the parsed list below.
        if not raw_keys or 'uid' not in raw_keys:
            return '#undefined'
        raw_values = rec.ReportEnvironment_Values
        if not raw_values:
            return '#undefined'
        try:
            keys = json.loads(raw_keys)
            values = json.loads(raw_values)
        except (TypeError, ValueError):
            return '#undefined'
        if not isinstance(keys, list) or not isinstance(values, list) or 'uid' not in keys:
            return '#undefined'
        index = keys.index('uid')
        if index >= len(values):
            return '#undefined'
        return values[index]

    def __is_ignored(self, event_name):
        """
            Return True for a missing name or a name containing any ignored substring.
        """
        if event_name is None:
            return True
        return any(ignored in event_name for ignored in self.__ignored_events)

    def __create_composite_event_name(self, name, value, event_type):
        """
            Build the event name used downstream.

            :return: None for EVENT_ERROR events; the raw event type when the
                     name is empty; for 'dynametric' events the name extended
                     with details from the JSON ``value`` (the plain name if
                     ``value`` is missing or not valid JSON); otherwise the
                     UTF-8 decoded name.
        """
        if event_type.decode("utf-8") == 'EVENT_ERROR':
            return None
        if not name:
            return event_type
        name = name.decode("utf-8")
        if 'dynametric' in name:
            try:
                json_value = json.loads(value)
            except (TypeError, ValueError):
                # None or malformed payload: fall back to the bare name.
                return name
            return self.__parse_mail_android_dynametric_event(name, json_value)
        return name

    def __parse_mail_android_dynametric_event(self, name, json_value):
        """
            Append the short activity, fragment, view id and tag found in
            ``json_value`` to ``name``, '/'-separated, skipping absent parts.
            Empty or non-object JSON payloads leave ``name`` unchanged.
        """
        if not json_value or not isinstance(json_value, dict):
            return name
        composite_name = name
        activity = json_value.get('activity')
        if activity:
            # keep only the last component of a dotted class path
            composite_name += '/' + activity.split('.')[-1]
        fragment = json_value.get('fragment')
        if fragment:
            composite_name += '/' + fragment.split('.')[-1]
        view_id = json_value.get('view_id')
        if view_id:
            # keep only the part after the last '/'
            composite_name += '/' + view_id.split('/')[-1]
        tag = json_value.get('tag')
        if tag:
            composite_name += '/' + tag
        return composite_name


class EventNameDictionary:
    """
        Mapping event name -> integer id (plus an optional inverse for decoding).

        Id 0 is reserved for EVENT_START and -1 for the virtual end-of-session
        event EVENT_END; other names get ids 1, 2, ... in insertion order.
    """

    def __init__(self):
        self.__mapping = {}
        self.__imapping = {}
        self.__max_value = 1
        # Start in the same reserved state that load()/update_with() use, so a
        # fresh dictionary already has EVENT_START == 0 and EVENT_END == -1.
        self.__clear()

    def load(self, tbl):
        """
            Replace the current state with the state stored in a yt table
            (rows with 'k' = name and 'v' = id, as written by save()).

            :raises ValueError: if the table does not map EVENT_START to 0
        """
        self.__clear()
        for rec in tbl.read():
            self.__mapping[rec['k']] = rec['v']
        # The next id must be strictly greater than every loaded id; otherwise
        # the first add() would hand out an id that is already taken.
        self.__max_value = max(self.__mapping.values()) + 1

        if self.__mapping.get('EVENT_START') != 0:
            raise ValueError('event dictionary maps EVENT_START to {!r}, expected 0'.format(
                self.__mapping.get('EVENT_START')))

    def update_with(self, cluster, tbl):
        """
            Add every distinct event_name from the given table.
            The previous state is discarded first.
        """
        self.__clear()

        job = cluster.job()

        src = job.table(tbl).unique('event_name').put('$tmp/dict')
        job.run()
        for rec in src.read():
            self.add(rec.event_name)

    def add(self, element):
        """
        Add a key to the mapping if it is not there yet.
        :param element: key
        :return: numeric id of the key
        """
        if element not in self.__mapping:
            self.__mapping[element] = self.__max_value
            self.__max_value += 1
        return self.__mapping[element]

    def __getitem__(self, item):
        return self.__mapping[item]

    def save(self, tbl):
        """
            Write the mapping to ``tbl`` as Record(k=name, v=id) rows.
        """
        tbl.write([Record(k=k, v=v) for k, v in self.__mapping.items()])

    def __clear(self):
        # Reset to the reserved ids only; also drop any stale inverse mapping.
        self.__mapping.clear()
        self.__imapping.clear()
        self.__max_value = 1
        self.__mapping['EVENT_START'] = 0
        self.__mapping['EVENT_END'] = -1

    def invert(self):
        """
            (Re)build the id -> name mapping used by decode().
        """
        self.__imapping.clear()
        for k, v in self.__mapping.items():
            self.__imapping[v] = k

    def decode(self, l):
        """
            Translate a list of ids back to names; requires a prior invert().
        """
        return [self.__imapping[x] for x in l]


class Context:
    """
        Context in which an action happens: the most recent preceding events.
        Essentially a wrapper around a list of at most ``max_len`` elements;
        a negative ``max_len`` (as used by from_str) means "unbounded".
    """

    def __init__(self, max_len, state=None):
        if state is None:
            state = []
        self.__max_len = max_len
        self.__state = state

    def update(self, value):
        """
            Append ``value`` to the state. If the state becomes longer than
            max_len, drop the oldest elements. Returns the (live) state list.
        """
        self.__state.append(value)
        if self.__max_len >= 0:
            overflow = len(self.__state) - self.__max_len
            if overflow > 0:
                # One slice deletion instead of repeated O(n) pop(0) calls.
                del self.__state[:overflow]
        return self.get()

    def get(self):
        return self.__state

    def to_str(self):
        """Serialize as '|'-joined element strings ('' for an empty context)."""
        return '|'.join([str(x) for x in self.__state])

    @staticmethod
    def from_str(s):
        """Inverse of to_str() for integer elements; the result is unbounded."""
        return Context(-1, [int(x) for x in s.split('|')] if s else [])

    def state(self):
        return self.__state

    def __eq__(self, other):
        # Duck-typed on state() so contexts unpickled under another module
        # name still compare; anything else defers to Python's default.
        other_state = getattr(other, 'state', None)
        if not callable(other_state):
            return NotImplemented
        return self.__state == other_state()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class Edge:
    """
        (Almost) an edge of the transition graph: a triple of
        context  - what the user did before, compared with == in match();
        target   - the vertex the transition leads to;
        weight   - total number of transitions observed along this edge.
        Each edge also carries a "visited" flag, initially unset.
    """

    def __init__(self, context, target, weight):
        # Attribute names are kept stable: graphs are pickled and reloaded.
        self.__visited = False
        self.__weight = weight
        self.__target = target
        self.__context = context

    def visit(self):
        """Mark the edge as visited."""
        self.__visited = True

    def is_visited(self):
        """Return whether visit() has been called."""
        return self.__visited

    def weight(self):
        """Return the number of transitions recorded for the edge."""
        return self.__weight

    def match(self, ctx, target):
        """
            Tell whether this edge allows a transition from context ``ctx``
            into ``target``: both the target and the context must be equal.
        """
        same_target = self.__target == target
        if not same_target:
            return same_target
        return self.__context == ctx


class Node:
    """
        Vertex of the transition graph.

        Holds a name (event_name) and a list of outgoing edges; every edge
        carries its own context and target vertex.
        About ``canonized``: the plan was either a builder or a way to freeze
        the graph at some point to forbid further changes. It was never
        finished, so nothing ever sets the flag to True.
    """

    def __init__(self, name, edges):
        self.__name = name
        self.__edges = edges
        self.__canonized = False

    def add_edge(self, edge):
        """
            Append an outgoing edge to this vertex.

            :raises ValueError: if the node is canonized
        """
        if not self.__canonized:
            # todo check unique
            self.__edges.append(edge)
            return
        raise ValueError("can't modify canonized node")

    def edges(self):
        return self.__edges

    def find_edge(self, ctx, name):
        """
            Look up the first outgoing edge with context ``ctx`` and target ``name``.

            :return: (edge, whether the edge is visited), or (None, False)
                     when no edge matches
        """
        found = next((candidate for candidate in self.__edges if candidate.match(ctx, name)), None)
        if found is None:
            return None, False
        return found, found.is_visited()


class Graph:
    """
        Transition graph: a wrapper around a dict of vertices keyed by name.
    """

    def __init__(self, nodes=None):
        if nodes is None:
            nodes = {}
        self.__nodes = nodes

    def score(self):
        """
            Coverage of the graph by visited edges.

            :return: (visited_weight / total_weight, visited_weight, total_weight).
                     A graph without any edge weight counts as fully covered
                     (ratio 1.0) instead of raising ZeroDivisionError.
        """
        total = 0.
        visited = 0.
        for node in self.__nodes.values():
            for edge in node.edges():
                weight = edge.weight()
                total += weight
                if edge.is_visited():
                    visited += weight
        if not total:
            return 1., visited, total
        return visited / total, visited, total

    def add_edge(self, src, context, target, weigth):
        """
            Add an edge src -> target with the given context, creating the
            missing vertices.
            weigth - number of transitions along the edge (the historical
            spelling of the parameter name is kept for keyword callers).
        """
        src = self.__get_or_create_node(name=src)
        self.__get_or_create_node(name=target)
        edge = Edge(context, target, weigth)
        src.add_edge(edge)

    def __get_or_create_node(self, name):
        if name not in self.__nodes:
            self.__nodes[name] = Node(name, [])
        return self.__nodes[name]

    def nodes(self):
        return self.__nodes.values()

    def __getitem__(self, item):
        """Return the vertex called ``item`` or None when it does not exist."""
        return self.__nodes.get(item)


class GraphBuilder:
    """
        Builds the transition Graph from a per-event table
        (uuid, puid, version, event_name, timestamp, event_number).
    """

    def __init__(self, event_name_dict, session_len):
        self.__event_name_dict = event_name_dict
        self.__session_len = session_len

    def __do_split(self):
        """
            Return a reducer that, for every (uuid, puid, version) group,
            orders events by (timestamp, event_number), maps names to ids and
            splits the stream into sessions at every EVENT_START (id 0).
            A session starts with EVENT_START unless it is the first one of a
            group whose first event is something else.
        """
        event_name_dict = self.__event_name_dict

        def split_sessions(groups):
            for key, records in groups:
                events = [(event_name_dict[rec.event_name], rec.timestamp, rec.event_number)
                          for rec in records]
                events.sort(key=lambda x: (x[1], x[2]))

                actual_session = []
                # The trailing (0, 0, 0) sentinel acts as one more EVENT_START
                # so the last session of the group is flushed too.
                for e in itertools.chain(events, [(0, 0, 0)]):
                    e = e[0]
                    if e == 0 and actual_session:
                        yield Record(uuid=key.uuid, puid=key.puid, version=key.version, session=actual_session)
                        actual_session = []
                    actual_session.append(e)

        return split_sessions

    def __unfold_session(self):
        """
            Return a mapper that turns each session into
            Record(context, actual_event, next_event, c) transition counts.

            ``context`` is the to_str() of the at most session_len events that
            precede actual_event; the last event of a session transitions to
            EVENT_END. Counts are pre-aggregated in memory and flushed whenever
            more than 10000 distinct keys pile up, so a key may be emitted more
            than once - consumers must sum ``c``.
        """
        session_len = self.__session_len
        event_name_dict = self.__event_name_dict
        flush_threshold = 10000

        def unfold(records):
            cache = {}
            for rec in records:
                ctx = Context(session_len)
                session = rec.session
                for i in range(len(session)):
                    actual_event = session[i]
                    next_event = session[i + 1] if i + 1 < len(session) else event_name_dict['EVENT_END']
                    key = (ctx.to_str(), actual_event, next_event)
                    cache[key] = cache.get(key, 0) + 1

                    if len(cache) > flush_threshold:
                        for (context, actual, following), count in cache.items():
                            yield Record(context=context, actual_event=actual, next_event=following, c=count)
                        cache.clear()

                    ctx.update(actual_event)

            for (context, actual, following), count in cache.items():
                yield Record(context=context, actual_event=actual, next_event=following, c=count)

        return unfold

    def build(self, cluster, src, dst):
        """
            Split ``src`` events into sessions ($home/sessions), aggregate the
            transition counts into ``dst`` and return the Graph loaded from ``dst``.
        """
        job = cluster.job()

        job.table(src) \
            .groupby('uuid', 'puid', 'version') \
            .reduce(self.__do_split()) \
            .put('$home/sessions')
        job.run()

        job = cluster.job()
        job.table('$home/sessions') \
            .map(self.__unfold_session(), intensity='cpu') \
            .groupby('context', 'actual_event', 'next_event') \
            .aggregate(
                count=na.sum('c')
            ) \
            .put(dst)
        job.run()

        # Re-open the materialized result by its path; job.table() takes a
        # table path, not the stream object returned by put().
        return self.load(cluster.job().table(dst))

    def load(self, tbl, min_count=100):
        """
            Build a Graph from an aggregated transitions table, keeping only
            transitions observed more than ``min_count`` times.
        """
        g = Graph()
        for rec in tbl.read():
            if rec.count > min_count:
                context = Context.from_str(rec.context)
                g.add_edge(rec.actual_event, context, rec.next_event, rec.count)
        return g


class Engine:
    """
        Greedy "advice" engine on top of a transition Graph.

        Sessions are walked over the graph; walked edges can be marked as
        visited, and advice() repeatedly picks the sessions that cover the
        most not-yet-visited transition weight.
    """

    def __init__(self, event_dictionary, graph, max_ctx_len):
        self.__event_dictionary = event_dictionary
        self.__graph = graph
        self.__max_ctx_len = max_ctx_len

    @staticmethod
    def test(graph, max_ctx_len, session):
        """
            Walk ``session`` over ``graph`` without modifying it.

            The walk stops at the first event without a vertex or without a
            matching edge. Each matched edge that is not visited yet adds its
            weight once.

            :return: (total score, running scores - one entry per walked step)
        """
        scores = []
        score = 0
        context = Context(max_ctx_len, state=[])
        visited_edges = set()
        for i, event in enumerate(session):
            # -1 is the id of the virtual EVENT_END event.
            next_event = session[i + 1] if i + 1 < len(session) else -1
            node = graph[event]
            if node is None:
                break
            edge, is_visited = node.find_edge(context, next_event)
            if edge is None:
                break
            if not is_visited and edge not in visited_edges:
                score += edge.weight()
                visited_edges.add(edge)
            context.update(event)
            scores.append(score)
        return score, scores

    def apply_session(self, session):
        """
            Walk ``session`` over the graph the same way as test() and mark
            every matched edge as visited; stop at the first unmatched step.
        """
        context = Context(self.__max_ctx_len, state=[])
        for i, event in enumerate(session):
            next_event = session[i + 1] if i + 1 < len(session) else -1
            node = self.__graph[event]
            if node is None:
                break
            edge, _ = node.find_edge(context, next_event)
            if edge is None:
                break
            edge.visit()
            context.update(event)

    class SessionScorer:
        """
            Mapper scoring sessions against the graph shipped as the 'graph' file.

            Sessions longer than 100 events and duplicates are skipped. Each
            session is truncated at the first step that reaches more than 95% of
            its score, and Record(score, session, fake_score, l, ol) is yielded,
            where fake_score is the negated, length-penalized score used for sorting.
        """

        def __init__(self, max_ctx_len):
            self.__max_ctx_len = max_ctx_len
            # Placeholder; the real graph is loaded from the 'graph' file in __call__.
            self.__graph = Graph()

        def __call__(self, records):
            # The graph is pickled by the driver script, where these classes
            # live in __main__; expose them there so unpickling can resolve them.
            main_module = sys.modules['__main__']
            main_module.Graph = Graph
            main_module.Node = Node
            main_module.Edge = Edge
            main_module.Context = Context
            main_module.EventNameDictionary = EventNameDictionary

            score_cache = set()  # never filled, see the TODO below
            sess_cache = set()
            # NOTE: pickle.load is acceptable only because 'graph' is written
            # by Engine.advice(); never point this at untrusted data.
            with open('graph', 'rb') as f:
                self.__graph = pickle.load(f)

            for rec in records:
                session = rec.session
                if len(session) > 100:
                    continue
                sess_str = ' '.join([str(x) for x in session])
                if sess_str in sess_cache:
                    continue
                sess_cache.add(sess_str)

                score, scores = Engine.test(self.__graph, self.__max_ctx_len, session)
                if score <= 0 or score in score_cache:
                    continue

                # Keep only the prefix that already earns > 95% of the score.
                # scores[-1] == score, so the loop always breaks within range.
                actual_session = []
                for i in range(len(session)):
                    actual_session.append(session[i])
                    if scores[i] / (score * 1.) > 0.95:
                        score = scores[i]
                        break

                # Only a truncated prefix needs its own dedup check; the full
                # session was already recorded in sess_cache above.
                if len(actual_session) != len(session):
                    sess_str = ' '.join([str(x) for x in actual_session])
                    if sess_str in sess_cache:
                        continue
                    sess_cache.add(sess_str)
                # score_cache.add(score) # TODO ?!?!??!?!?!?!?
                # Floor division keeps the original Python 2 step-wise penalty.
                fake_coeff = max(1. - 0.01 * (len(actual_session) // 5), 0.5)
                yield Record(score=score, session=actual_session, fake_score=-score * fake_coeff,
                             l=len(actual_session), ol=len(session))

    def advice(self, cluster, top=100):
        """
            Greedily pick sessions that cover the most unvisited transition weight.

            Each round scores candidate sessions on the cluster (SessionScorer),
            applies the best candidate plus every other one that still keeps
            more than 75% of its score, and records (session, graph.score())
            pairs. Stops after ``top`` advices, when coverage reaches 0.999, or
            when no candidate session is left.

            :return: list of (session, (ratio, visited, total)) tuples
        """
        advices = []
        first_round = True

        while len(advices) < top and self.__graph.score()[0] < 0.999:
            job = cluster.job()

            # The first round reads all sessions; later rounds re-score the
            # (possibly truncated) candidates of the previous round.
            # NOTE(review): later rounds read and overwrite
            # $home/unscored_sessions in one operation - confirm YT allows it.
            source = '$home/sessions' if first_round else '$home/unscored_sessions'
            first_round = False

            scorer = Engine.SessionScorer(self.__max_ctx_len)
            # Ship the current graph state to the mappers as the 'graph' file.
            with open('graph', 'wb') as o:
                pickle.dump(self.__graph, o)

            job.table(source) \
                .map(scorer, intensity='ultra_cpu', files=['graph'], memory_limit=2048) \
                .put('$home/unscored_sessions')
            job.run()

            job = cluster.job()
            job.table('$home/unscored_sessions').sort('fake_score').put('$home/scored_sessions')
            job.run()

            print('current score: {} ({} / {})'.format(*self.__graph.score()))

            # Ascending fake_score == descending length-penalized score.
            records = []
            scores = []
            for rec in cluster.job().table('$home/scored_sessions').read():
                if len(records) >= 2500 or (scores and rec.score < scores[0] / 2):
                    break
                records.append(rec.session)
                scores.append(abs(rec.score))

            if not records:
                # No session covers unvisited weight any more: no progress is possible.
                break

            self.apply_session(records[0])
            advices.append((records[0], self.__graph.score()))

            for i in range(1, len(records)):
                cur_score, _ = Engine.test(self.__graph, self.__max_ctx_len, records[i])
                print('[info] score {} -> {}'.format(scores[i], cur_score))
                if cur_score > scores[i] * 0.75:
                    self.apply_session(records[i])
                    advices.append((records[i], self.__graph.score()))

            print('score after applying: {} ({} / {})'.format(*self.__graph.score()))
            print('advices: {}'.format(len(advices)))

            for adv, score in advices:
                print('{}'.format(self.__event_dictionary.decode(adv)))
        return advices

    def score(self):
        return self.__graph.score()


def process_metrika_log(cluster, api_keys, app_versions, start_date, end_date):
    """
        Filter and normalize the daily AppMetrica mobile logs between
        start_date and end_date and store the result in $home/metrika_events.

        Rows are pre-filtered on the cluster by api_keys[0] only;
        MetrikaLogProcessor then applies the complete set of checks.
    """
    date_range = '{' + '{0}..{1}'.format(start_date, end_date) + '}'
    processor = MetrikaLogProcessor(supported_api_keys=api_keys, supported_app_versions=app_versions)
    events_schema = dict(uuid=str, puid=str, event_name=str, version=str, timestamp=int, event_number=int)

    job = cluster.job().env(templates=dict(dates=date_range))

    source = job.table('logs/metrika-mobile-log/1d/@dates')
    prefiltered = source.filter(nf.equals('APIKey', api_keys[0]))
    prefiltered.map(processor).put('$home/metrika_events', schema=events_schema)

    job.run()

def extract_test_data(cluster, test_uids, app_versions):
    """
        Copy the sessions of the given test users (matched on puid) from
        $home/sessions into $home/test_data.

        ``app_versions`` is accepted but currently unused: the version filter
        (qf.one_of('version', app_versions)) is disabled.
    """
    job = cluster.job()
    sessions = job.table('$home/sessions')
    test_sessions = sessions.filter(qf.one_of('puid', test_uids))
    test_sessions.put('$home/test_data')
    job.run()

def create_event_dict(cluster):
    """
        Build an EventNameDictionary from the event names found in
        $home/metrika_events, persist it to $home/event_dict and return it.
    """
    dictionary = EventNameDictionary()
    dictionary.update_with(cluster, '$home/metrika_events')
    target_table = cluster.job().table('$home/event_dict')
    dictionary.save(tbl=target_table)
    return dictionary


def handle_options():
    """
        Build the command-line parser.

        Global options: --home, --version. Sub-commands: prepare,
        extract-test-data, analyze. The chosen sub-command is stored in
        ``args.action``, so it is known even when global options precede it
        (``--home X prepare ...``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--home", dest="home", default='home/mail-logs/vkokarev/app')
    parser.add_argument("--version", dest="version", default='general')

    commands = parser.add_subparsers(title='sub-commands', dest='action')
    # Python 3 makes sub-commands optional by default; keep the Python 2
    # behaviour where one must be given (the attribute works on both).
    commands.required = True

    prepare_parser = commands.add_parser('prepare')
    prepare_parser.add_argument("--start-date", dest="start_date", required=True)
    prepare_parser.add_argument("--end-date", dest="end_date", required=True)
    prepare_parser.add_argument('--api-keys', dest='api_keys', nargs='+', help='Allowed api keys', required=True)
    prepare_parser.add_argument('--app-versions', dest='app_versions', nargs='+', help='Allowed app versions', required=False, default=None)
    prepare_parser.add_argument("--session-size", dest="session_size", type=int, default=4)

    test_data_parser = commands.add_parser('extract-test-data')
    test_data_parser.add_argument('--app-versions', dest='app_versions', nargs='+', help='Allowed app versions', required=True)
    test_data_parser.add_argument('--test-uids', dest='test_uids', nargs='+', help='Test uids', required=True)
    test_data_parser.add_argument("--session-size", dest="session_size", type=int, default=4)

    analyze_parser = commands.add_parser('analyze')
    analyze_parser.add_argument("--session-size", dest="session_size", type=int, default=4)

    return parser


def main():
    """
        Command-line entry point: dispatch to the selected sub-command.

        Requires the YT_TOKEN environment variable; --home selects the YT
        working directory used for the $home template.
    """
    args = handle_options().parse_args()
    # Prefer the parser-recorded sub-command; fall back to the first CLI
    # argument for parsers that do not store it.
    action = getattr(args, 'action', None) or sys.argv[1]

    # os.environ[...] would raise KeyError before the check could run.
    token = os.environ.get('YT_TOKEN')
    if not token:
        raise Exception('export YT_TOKEN env variable')

    cluster = clusters.yt.Hahn(token=token, pool='sherlock').env(
        templates=dict(
            home=args.home,
            tmp='$home/tmp/$job_name'
        )
    )

    if action == 'prepare':
        # prepare --api-keys=14836 --start-date 2019-06-03 --end-date 2019-06-10
        if len(args.api_keys) > 1:
            # process_metrika_log pre-filters the log by api_keys[0] only.
            raise ValueError('exactly one --api-keys value is supported')
        process_metrika_log(cluster, args.api_keys, args.app_versions, args.start_date, args.end_date)
        print('Creating dictionary')
        event_dictionary = create_event_dict(cluster)
        print('building')
        builder = GraphBuilder(event_dictionary, args.session_size)
        builder.build(cluster, '$home/metrika_events', '$home/graph')
    elif action == 'extract-test-data':
        # extract-test-data --test-uids 681081409 684826209 683247944 409538438 457576940 410245296 649888775 663302301 727477879 640219014 640224402 646915928 652552175 --app-versions 4.26.1
        extract_test_data(cluster, args.test_uids, args.app_versions)
    elif action == 'analyze':
        event_dictionary = EventNameDictionary()

        print('loading dict')
        event_dictionary.load(cluster.job().table('$home/event_dict'))
        event_dictionary.invert()

        print('loading graph')
        builder = GraphBuilder(event_dictionary, args.session_size)
        g = builder.load(cluster.job().table('$home/graph'))

        engine = Engine(event_dictionary, g, args.session_size)
        print('applying test sessions')
        for rec in cluster.read('$home/test_data'):
            engine.apply_session(rec.session)

        # Local snapshot of the graph after applying the test sessions.
        with open('ggg2', 'wb') as o:
            pickle.dump(g, o)

        print('Current score is: {} ({} / {})'.format(*engine.score()))
        advices = engine.advice(cluster)
        print('{}'.format(advices))
    else:
        raise ValueError('unknown action: {}'.format(action))


if __name__ == '__main__':
    # Run the command-line interface only when executed as a script, not on import.
    main()
