# coding=utf-8
import argparse
import itertools
import json
import math
import os
import subprocess

import numpy as np
import yt.wrapper as yt
import utils



def module_filter(module):
    """Tell whether *module* passes the filter installed into the yt pickling config.

    A module is rejected (``False``) when its ``__name__`` contains one of
    the excluded fragments, when it has no ``__file__``, or when that file
    is a ``.so`` shared object.  Everything else is accepted (``True``).
    Presumably accepted modules are the ones yt ships to the jobs -- confirm
    against the yt wrapper documentation.
    """
    excluded_fragments = ('numpy', 'yt_yson_bindings', 'hashlib', 'hmac')
    name = getattr(module, '__name__', '')
    if any(fragment in name for fragment in excluded_fragments):
        return False

    # Modules without a backing file and compiled extensions are rejected too.
    path = getattr(module, '__file__', '')
    return bool(path) and not path.endswith('.so')


# Global configuration of the yt client; applied as a side effect of importing
# this module.
yt.config["proxy"]["url"] = "hahn.yt.yandex.net"
# Memory limit setting; presumably bytes per job (~6 GB) -- TODO confirm.
yt.config["memory_limit"] = 6000000000
# Pickling: auto-collect dynamic libraries and filter modules through
# module_filter().
yt.config['pickling']['dynamic_libraries']['enable_auto_collection'] = True
yt.config['pickling']['module_filter'] = module_filter
# Tables are exchanged in YSON format unless a call overrides the format.
yt.config["tabular_data_format"] = yt.YsonFormat()

# Pool used for operations and parallel table writes.
yt.config["pool"] = 'sherlock'
yt.config["write_parallel"]["enable"] = True


def read_crashes(date, version_name, api_key):
    """Extract the crashes of one app version from the daily Metrika log.

    Runs ``MetrikaCrashesSelectMapper`` over the log table of ``date`` and
    stores the result in a per api-key / version / date table.  When that
    table already exists nothing is recomputed.

    Returns the YT path of the table holding the selected crashes.
    """
    wanted_columns = [
        'APIKey',
        'EventID',
        'EventValueCrash',
        'OperatorName',
        'Manufacturer',
        'Model',
        'AppVersionName',
        'ConnectionType',
        'OSApiLevel',
        'EventType'
    ]
    source = yt.TablePath(
        '//logs/metrika-mobile-log/1d/{}'.format(date),
        sorted_by=['APIKey'],
        columns=wanted_columns)
    destination = '//home/mobilesearch/salavat/retrace/api_key_{}/version_{}/1d/{}'.format(api_key, version_name, date)

    # An existing table is treated as the finished result of an earlier run.
    if yt.exists(destination):
        return destination

    yt.create('table', destination, recursive=True, ignore_existing=True)
    selector = MetrikaCrashesSelectMapper(api_key, version_name)
    yt.run_map(selector, source, destination)
    return destination

def cartesian2(arrays):
    """Return the cartesian product of several 1-D sequences.

    The result has one row per combination and one column per input
    sequence; rows are ordered with the last sequence varying fastest.

    The output dtype is the promotion of ``int`` with the dtypes of all
    inputs, so integer inputs yield an integer table while float inputs
    keep their fractional part instead of being truncated.
    """
    arrays = [np.asarray(a) for a in arrays]
    shape = tuple(len(a) for a in arrays)

    # Positions of every combination: column n holds indices into arrays[n].
    positions = np.indices(shape, dtype=int)
    positions = positions.reshape(len(arrays), -1).T

    # Gather into a separate buffer; writing values back into the integer
    # index grid would silently cast them to int.
    product = np.empty(positions.shape, dtype=np.result_type(int, *arrays))
    for column, values in enumerate(arrays):
        product[:, column] = values[positions[:, column]]

    return product


class CalculateDistances(object):
    """Mapper that turns a pair of event ids into a stack-trace distance row.

    ``c`` is the decay rate applied to how far a matched frame is from the
    top of the stack (the smaller of the two frame indices); ``o`` is the
    decay rate applied to the offset between the two matched positions.
    ``processed_crashes_dict`` maps an event id to a row whose ``'Trace'``
    value is the newline-separated stack trace.
    """

    def __init__(self, c, o, processed_crashes_dict):
        self._c = c
        self._o = o
        self._crashes_dict = processed_crashes_dict

    def __call__(self, row):
        """Yield one result row for the pair ``row['eventID1']`` / ``row['eventID2']``.

        The distance is ``1 - similarity / sig``, or 1.0 when ``sig`` is
        zero.  Pairs whose distance is not strictly positive yield nothing.
        Raises ``KeyError`` when an event id is missing from the crash
        dictionary.
        """
        outer_event_id = row['eventID1']
        inner_event_id = row['eventID2']
        outer_crash = self._crashes_dict[outer_event_id]['Trace']
        inner_crash = self._crashes_dict[inner_event_id]['Trace']

        all_costs, matched_frames_idx, sig, similarity, M = self.calculate_similarity(
            outer_crash.splitlines(), inner_crash.splitlines())
        distance = 1.0 - similarity / sig if sig != 0.0 else 1.0
        if distance > 0.0:
            yield {
                'eventID1': outer_event_id,
                'eventID2': inner_event_id,
                'distance': distance,
                'similarity': similarity,
                'sig': sig,
                # NOTE(review): these are lengths of the raw trace strings
                # (characters), not frame counts -- confirm that is intended.
                'outer_stack_size': len(outer_crash),
                'inner_stack_size': len(inner_crash),
                'matched_frames_idx': matched_frames_idx,
                'number_of_matched_frames': len(matched_frames_idx),
                'M': json.dumps(M),
                'stack_event1': outer_crash,
                'stack_event2': inner_crash,
                # NOTE(review): this is a numpy array; verify the table
                # format can serialise it.
                'all_costs': all_costs
            }

    def calculate_similarity(self, outer_crash, inner_crash):
        """Score two frame lists with a weighted longest-common-subsequence DP.

        Parameters are sequences of frame strings.  Returns a tuple
        ``(all_costs, matched_frames_idx, sig, similarity, M)`` where

        * ``all_costs`` is the weight ``exp(-c*min(i, j)) * exp(-o*|i - j|)``
          of every index pair ``(i, j)`` in row-major order, whether or not
          the frames match;
        * ``matched_frames_idx`` lists the ``(i, j)`` pairs of equal frames;
        * ``sig`` is ``sum(exp(-c*k))`` for ``k`` in ``1..min(len)``;
        * ``similarity`` is the bottom-right cell of ``M`` (0.0 when either
          trace has no frames);
        * ``M`` is the DP table as a list of lists.
        """
        outer_crash_len = len(outer_crash)
        inner_crash_len = len(inner_crash)

        # Vectorised weights for all possible frame pairs: position grids
        # are flattened row-major, i.e. (0,0), (0,1), ..., (1,0), ...
        outer_pos, inner_pos = np.indices((outer_crash_len, inner_crash_len), dtype=int)
        all_min_distances = np.minimum(outer_pos, inner_pos).ravel()
        all_frame_offsets = np.abs(outer_pos - inner_pos).ravel()
        all_costs = np.multiply(np.exp(-self._c * all_min_distances), np.exp(-self._o * all_frame_offsets))

        # DP table, initialised with zeros.
        M = [[0.0] * inner_crash_len for _ in range(outer_crash_len)]
        matched_frames_idx = []
        for i, outer_frame in enumerate(outer_crash):
            for j, inner_frame in enumerate(inner_crash):
                cost = 0.0
                if outer_frame == inner_frame:
                    matched_frames_idx.append((i, j))
                    min_frame_distance = min(i, j)
                    frame_offset = abs(i - j)
                    cost = math.exp(-self._c * min_frame_distance) * math.exp(-self._o * frame_offset)

                if i == 0 and j == 0:
                    # NOTE(review): a match of the two top frames contributes
                    # nothing here; kept as is because the normalisation
                    # below (sig starts at k=1) relies on it.
                    M[0][0] = 0
                elif i == 0:
                    M[0][j] = max(cost, M[0][j - 1])
                elif j == 0:
                    M[i][0] = max(cost, M[i - 1][0])
                else:
                    M[i][j] = max(
                        M[i - 1][j - 1] + cost,
                        M[i - 1][j],
                        M[i][j - 1]
                    )

        shortest = min(outer_crash_len, inner_crash_len)
        # Builtin sum: np.sum() over a generator is a deprecated call form.
        sig = sum(np.exp(-self._c * k) for k in range(1, shortest + 1))
        # With no frames on either side there is no bottom-right cell.
        similarity = M[-1][-1] if shortest else 0.0
        return all_costs, matched_frames_idx, sig, similarity, M


class StacktraceCleaner(object):
    """Mapper that replaces a row's trace with its preprocessed form.

    The trace is passed to ``utils.preprocess_stacktrace`` together with the
    configured trim size and caused-by number.  Rows whose preprocessed
    trace is empty are dropped.
    """

    def __init__(self, stacktrace_trim_size, caused_by_number):
        self._trim_size = stacktrace_trim_size
        self._caused_by_num = caused_by_number

    def __call__(self, row):
        """Yield a copy of ``row`` restricted to the known columns, with the cleaned trace."""
        cleaned = utils.preprocess_stacktrace(row['Trace'], self._trim_size, self._caused_by_num, False)
        if len(cleaned) == 0:
            return

        kept_columns = (
            'EventID',
            'Trace',
            'OperatorName',
            'Manufacturer',
            'Model',
            'AppVersionName',
            'ConnectionType',
            'OSApiLevel',
        )
        result = {}
        for column in kept_columns:
            if column in row:
                # Only the trace is rewritten; other columns pass through.
                result[column] = cleaned if column == 'Trace' else row[column]
        yield result


class GenerateIndexesPair(object):
    """Mapper that emits every unordered pair of event ids.

    The ids are the items obtained by iterating ``crashes_dict``; the input
    row itself is ignored and only triggers the generation.
    """

    def __init__(self, crashes_dict):
        self._crashes_dict = crashes_dict

    def __call__(self, row):
        """Yield ``{'eventID1': a, 'eventID2': b}`` for each 2-combination of ids."""
        for first_id, second_id in itertools.combinations(self._crashes_dict, 2):
            yield {
                'eventID1': first_id,
                'eventID2': second_id
            }



class MetrikaCrashesSelectMapper(object):
    """Mapper that keeps crash events of one API key and app version.

    A row is kept when its ``APIKey`` and ``AppVersionName`` equal the
    string forms of the configured values and its ``EventType`` is
    ``'EVENT_CRASH'``.  Kept rows are reduced to a fixed set of columns,
    with ``EventValueCrash`` renamed to ``Trace``.
    """

    def __init__(self, api_key, version_name):
        self._api_key = api_key
        self._version_name = version_name

    def __call__(self, row):
        """Yield the reduced row when ``row`` matches the filter, nothing otherwise."""
        is_wanted_crash = (
            row['APIKey'] == str(self._api_key)
            and row['AppVersionName'] == str(self._version_name)
            and row['EventType'] == 'EVENT_CRASH'
        )
        if not is_wanted_crash:
            return

        # (source column, output column) in output order.
        column_mapping = (
            ('EventID', 'EventID'),
            ('EventValueCrash', 'Trace'),
            ('OperatorName', 'OperatorName'),
            ('Manufacturer', 'Manufacturer'),
            ('Model', 'Model'),
            ('AppVersionName', 'AppVersionName'),
            ('ConnectionType', 'ConnectionType'),
            ('OSApiLevel', 'OSApiLevel'),
        )
        selected = {}
        for source, target in column_mapping:
            if source in row:
                selected[target] = row[source]
        yield selected



def main():
    """Command-line entry point: run the whole crash-clustering pipeline.

    Required options are ``--date``, ``--version_name``, ``--api_key`` and
    ``--token``.  The stages run in order: select crashes from the Metrika
    log, deobfuscate them, clean and preprocess them, generate event-id
    pairs, compute pairwise distances and finally cluster hierarchically.
    """
    parser = argparse.ArgumentParser(description='Try 1 of clustering crashes')
    required_options = (
        ('--date', 'date'),
        ('--version_name', 'version_name'),
        ('--api_key', 'API key from metrika'),
        ('--token', 'Yt Token'),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, required=True, help=help_text)
    args = parser.parse_args()
    yt.config["token"] = args.token

    # Read crashes from Metrika into a separate table.
    raw_table = read_crashes(args.date, args.version_name, args.api_key)
    print('reading crashes...')
    raw_crashes = yt.read_table(raw_table, enable_read_parallel=True)

    # Deobfuscate and write to YT.
    print('deobfuscate crashes...')
    deobfuscated_table = deobfuscate_crashes(args.api_key, raw_crashes, args.date, args.version_name)

    # Clean and preprocess crashes.
    print('cleaning and preprocessing crashes...')
    processed_table = clean_and_process_crashes(args.api_key, args.date, deobfuscated_table, args.version_name)

    # Index the preprocessed crashes by event id; the pair generation and
    # the distance computation both work from this mapping.
    print('reading cleaned and preprocessed crashes...')
    processed_by_id = {
        crash['EventID']: crash
        for crash in yt.read_table(processed_table, enable_read_parallel=True, format=yt.JsonFormat())
    }

    print('writing index pairs...')
    pairs_table = write_indexes_pairs(args.api_key, args.date, args.version_name, processed_by_id)

    # Calculate pairwise distances.
    print('calculate distances...')
    distances_table = calculate_distances(
        args.api_key, args.date, pairs_table, args.version_name, processed_by_id)

    # Hierarchical clustering, reported against the deobfuscated crashes.
    print('hierarchical clustering...')
    distance_threshold = 0.05
    deobfuscated_by_id = {
        crash['EventID']: crash
        for crash in yt.read_table(deobfuscated_table, enable_read_parallel=True, format=yt.JsonFormat())
    }
    perform_hier_clustering(
        distances_table, deobfuscated_by_id, processed_by_id, distance_threshold,
        args.api_key, args.version_name, args.date)


def _load_distance_matrix(distances_table, event_id_to_index, cache_path='distances.txt'):
    """Return the square pairwise distance matrix for the clustering step.

    A matrix previously saved to ``cache_path`` is loaded as is.  Otherwise
    the matrix is built from the ``distance`` / ``eventID1`` / ``eventID2``
    columns of ``distances_table`` (pairs absent from the table keep the
    default distance 1.0, the diagonal is 0.0) and saved to ``cache_path``.
    """
    if os.path.exists(cache_path):
        return np.loadtxt(cache_path)

    size = len(event_id_to_index)
    distances = np.ones((size, size), order='C')
    np.fill_diagonal(distances, 0.0)

    lines = yt.read_table(
        yt.TablePath(distances_table, columns=['distance', 'eventID1', 'eventID2']),
        enable_read_parallel=True)
    for number_of_lines_read, line in enumerate(lines, 1):
        i = event_id_to_index[line['eventID1']]
        j = event_id_to_index[line['eventID2']]
        # The table holds each unordered pair once; mirror it.
        distances[i, j] = distances[j, i] = float(line['distance'])
        if number_of_lines_read % 1000000 == 0:
            print('Read {} rows..'.format(number_of_lines_read))

    np.savetxt(cache_path, distances)
    return distances


def perform_hier_clustering(distances_table, crashes_dict, processed_crashes_dict, threshold, api_key, version_name, date):
    """Cluster crashes by complete-linkage hierarchical clustering and store the result in YT.

    Parameters:
        distances_table: YT path of the pairwise distance table; used when
            no local ``distances.txt`` matrix is present.
        crashes_dict: mapping of event id to the crash row (must provide
            ``AppVersionName``, ``ConnectionType``, ``Manufacturer``,
            ``Model``, ``OSApiLevel``, ``OperatorName`` and ``Trace``).  Its
            iteration order defines the row/column order of the matrix.
        processed_crashes_dict: mapping of event id to the preprocessed row;
            its ``Trace`` is stored as ``Processed`` (``None`` when the
            event has no preprocessed row).
        threshold: distance cut-off passed to ``fcluster``.
        api_key, version_name, date: select local directories and YT paths.

    Writes a ``clusters_histo`` table with the size of every cluster and
    appends the members of cluster ``k`` (1-based label) to the table
    ``cluster_{k-1}``.
    """
    # Local working directories api_key_*/version_*/<date>, created level
    # by level.
    local_dir = ''
    for part in ('api_key_{}'.format(api_key), 'version_{}'.format(version_name), date):
        local_dir = os.path.join(local_dir, part)
        if not os.path.exists(local_dir):
            os.mkdir(local_dir)

    from scipy.cluster.hierarchy import complete, fcluster
    import scipy.spatial.distance as ssd

    # Position of every event in the distance matrix, and back.
    index_to_event_id = list(crashes_dict)
    event_id_to_index = {event_id: index for index, event_id in enumerate(index_to_event_id)}

    distances = _load_distance_matrix(distances_table, event_id_to_index)

    dist_array = ssd.squareform(distances).astype(np.float64)
    linkage = complete(dist_array)
    print(linkage)
    clusters = fcluster(Z=linkage, t=threshold, criterion='distance')
    clusters = clusters.astype(int)

    # One exact count per cluster label.  Using the labels as np.histogram
    # bin edges would produce one bin fewer than there are clusters.
    cluster_labels, cluster_sizes = np.unique(clusters, return_counts=True)
    number_of_clusters = cluster_labels[-1]
    print('for t = {} number of uniq clusters = {}'.format(threshold, number_of_clusters))

    root_clusters_table = '//home/mobilesearch/salavat/retrace/api_key_{}/version_{}/clusters/1d/{}/{}'.format(api_key, version_name, date, threshold)
    histogram_table = os.path.join(root_clusters_table, 'clusters_histo')
    if not yt.exists(histogram_table):
        yt.create_table(histogram_table, recursive=True)

    histo_recs = [
        {'bin': int(label), 'count': int(size)}
        for label, size in zip(cluster_labels, cluster_sizes)
    ]
    histo_recs.sort(key=lambda rec: rec['count'], reverse=True)
    yt.write_table(yt.TablePath(histogram_table), histo_recs)

    cluster_tables = [
        os.path.join(root_clusters_table, 'cluster_{}'.format(number))
        for number in range(len(cluster_labels))
    ]
    for cluster_table in cluster_tables:
        if not yt.exists(cluster_table):
            yt.create('table', cluster_table, recursive=True, ignore_existing=True)

    cluster_members = [[] for _ in cluster_tables]
    for index, cluster_num in enumerate(clusters):
        event_id = index_to_event_id[index]
        event = crashes_dict[event_id]
        # The preprocessed dictionary is keyed by event id, so look the
        # event up directly instead of scanning positional indices.
        processed_row = processed_crashes_dict.get(event_id)
        processed_crash = processed_row['Trace'] if processed_row is not None else None
        # fcluster labels start at 1; tables and lists are 0-based.
        cluster_members[cluster_num - 1].append(
            {
                'EventID': event_id,
                'AppVersionName': event['AppVersionName'],
                'ConnectionType': event['ConnectionType'],
                'Manufacturer': event['Manufacturer'],
                'Model': event['Model'],
                'OSApiLevel': event['OSApiLevel'],
                'OperatorName': event['OperatorName'],
                'Trace': event['Trace'],
                'Processed': processed_crash
            }
        )

    for cluster_table, members in zip(cluster_tables, cluster_members):
        yt.write_table(yt.TablePath(cluster_table, append=True), members)



def calculate_distances(api_key, date, indexes_pairs_table, version_name, processed_crashes_dict):
    """Compute pairwise trace distances for every row of ``indexes_pairs_table``.

    The result table lives under the api-key / version / date specific
    ``distances`` path.  When it already exists the map operation is
    skipped.  Returns the YT path of the distances table.
    """
    target = '//home/mobilesearch/salavat/retrace/api_key_{}/version_{}/distances/1d/{}'.format(
        api_key, version_name, date)
    if yt.exists(target):
        return target

    yt.create('table', target, recursive=True, ignore_existing=True)
    # Both decay coefficients of the similarity measure are fixed to 1.
    mapper = CalculateDistances(1, 1, processed_crashes_dict)
    yt.run_map(mapper, indexes_pairs_table, target, job_count=10000)
    return target


def write_indexes_pairs(api_key, date, version_name, crashes_dict):
    """Materialise all event-id pairs of ``crashes_dict`` as a YT table.

    The pairs are produced by ``GenerateIndexesPair`` inside a map
    operation that runs over a temporary single-row table.  When the target
    table already exists nothing is done.  Returns the YT path of the pairs
    table.
    """
    target = '//home/mobilesearch/salavat/retrace/api_key_{}/version_{}/indexes_pairs/1d/{}'.format(api_key, version_name, date)
    if yt.exists(target):
        return target

    yt.create('table', target, recursive=True, ignore_existing=True)

    # A one-row input makes the mapper run exactly once.
    seed_table = yt.create_temp_table('//home/mobilesearch/salavat/retrace')
    yt.write_table(seed_table, [{'tmp': 1}])

    yt.run_map(GenerateIndexesPair(crashes_dict), seed_table, target)
    return target


def clean_and_process_crashes(api_key, date, deobfuscated_crashes_table, version_name):
    """Run ``StacktraceCleaner`` over the deobfuscated crashes table.

    The output goes to the api-key / version / date specific
    ``preprocessed`` path; an existing table is reused untouched.  Returns
    the YT path of the preprocessed table.
    """
    target = '//home/mobilesearch/salavat/retrace/api_key_{}/version_{}/preprocessed/1d/{}'.format(api_key, version_name, date)
    if yt.exists(target):
        return target

    yt.create('table', target, recursive=True, ignore_existing=True)
    # Trim size 10, caused-by number -1.
    cleaner = StacktraceCleaner(10, -1)
    yt.run_map(cleaner, deobfuscated_crashes_table, target)
    return target


def deobfuscate_crashes(api_key, crashes_list, date, version_name):
    """Deobfuscate every crash trace with ``retrace.jar`` and store the result in YT.

    For each crash the obfuscated trace is written to
    ``api_key_*/version_*/<date>/<EventID>.txt``, passed through
    ``java -jar retrace.jar mapping_<version_name>.txt <file>`` and the
    output is saved next to it as ``<EventID>.txt.deobfuscated``.  The
    deobfuscated rows are then written to the ``deobfuscated`` YT table.
    When that table already exists no crash is processed.

    Parameters:
        crashes_list: iterable of rows providing ``EventID``, ``Trace``,
            ``OperatorName``, ``Manufacturer``, ``Model``,
            ``AppVersionName``, ``ConnectionType`` and ``OSApiLevel``.

    Returns the YT path of the deobfuscated table.
    """
    # Local working directories, created level by level.
    date_path = ''
    for part in ('api_key_{}'.format(api_key), 'version_{}'.format(version_name), date):
        date_path = os.path.join(date_path, part)
        if not os.path.exists(date_path):
            os.mkdir(date_path)

    output_path = '//home/mobilesearch/salavat/retrace/api_key_{}/version_{}/deobfuscated/1d/{}'.format(api_key, version_name, date)
    if yt.exists(output_path):
        return output_path

    mapping_file = 'mapping_{}.txt'.format(version_name)
    new_records = []
    for single_crash in crashes_list:
        crash_id = single_crash['EventID']
        full_file_path = os.path.join(date_path, '{}.txt'.format(crash_id))
        with open(full_file_path, 'w') as crash_file:
            crash_file.write(single_crash['Trace'])

        # Text-mode pipes: without universal_newlines the output is bytes
        # on Python 3 and cannot be written to a text file.
        p = subprocess.Popen(['java', '-jar', 'retrace.jar', mapping_file, full_file_path],
                             stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             universal_newlines=True)
        out, err = p.communicate()
        if p.returncode != 0:
            # Keep going (best effort) but do not hide the failure.
            print('retrace exited with code {} for event {}: {}'.format(p.returncode, crash_id, err))

        full_file_path_deobf = os.path.join(date_path, '{}.txt.deobfuscated'.format(crash_id))
        with open(full_file_path_deobf, 'w') as deobf_file:
            deobf_file.write(out)

        new_records.append(
            {
                'EventID': crash_id,
                'Trace': out,
                'OperatorName': single_crash['OperatorName'],
                'Manufacturer': single_crash['Manufacturer'],
                'Model': single_crash['Model'],
                'AppVersionName': single_crash['AppVersionName'],
                'ConnectionType': single_crash['ConnectionType'],
                'OSApiLevel': single_crash['OSApiLevel']
            }
        )

    # Create the table only once all records are ready: an early failure
    # must not leave an empty table that later runs mistake for a result.
    yt.create('table', output_path, recursive=True, ignore_existing=True)
    print('writing deobfuscated table..')
    yt.write_table(output_path, new_records)
    print('deobfuscated table written..')
    return output_path


# Run the pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()
