# -*- coding: utf-8 -*-
from collections import defaultdict
from dateutil import parser
import io
import json
import logging
import socket
from zipfile import ZipFile

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment
# from sandbox.projects.sandbox.resources import LXC_CONTAINER

from sandbox.projects.mssngr.rtc import util

# CIDR prefixes whose addresses vpn_enabled() treats as VPN address space.
# RtcLogSupport.on_execute() converts them into ipaddr networks in vpn_nets.
vpn_prefixes = [
    "5.45.203.96/27",
    "5.45.203.128/25",
    "5.255.237.0/24",
    "5.255.244.0/23",
    "37.9.72.0/24",
    "37.9.83.128/25",
    "37.9.84.0/24",
    "37.9.85.0/24",
    "37.9.89.0/24",
    "37.9.90.0/24",
    "37.9.104.0/24",
    "37.9.105.0/24",
    "37.9.110.0/24",
    "37.9.111.0/24",
    "37.9.114.0/24",
    "37.9.115.0/24",
    "37.9.116.0/23",
    "37.9.120.0/24",
    "37.9.121.0/24",
    "37.9.122.0/24",
    "37.9.123.0/24",
    "37.140.140.0/24",
    "37.140.141.128/25",
    "37.140.162.0/23",
    "37.140.169.0/24",
    "37.140.182.0/24",
    "37.140.183.0/24",
    "87.250.227.0/24",
    "87.250.246.128/25",
    "93.158.136.0/24",
    "93.158.154.0/24",
    "93.158.159.0/24",
    "95.108.188.0/23",
    "95.108.190.0/23",
    "95.108.193.0/24",
    "95.108.200.0/23",
    "95.108.202.0/23",
    "100.43.78.192/26",
    "141.8.130.0/23",
    "141.8.157.0/28",
    "141.8.184.0/22",
    "149.5.241.240/28",
    "178.154.160.0/22",
    "178.154.188.0/22",
    "213.180.195.0/24",
    "172.16.192.0/19",
]

# Parsed form of vpn_prefixes (ipaddr networks); empty until
# RtcLogSupport.on_execute() fills it.
vpn_nets = []


def vpn_enabled(candidates):
    """Return True if any ICE candidate address looks like a VPN address.

    :param candidates: dict whose values carry a ``"candidate"`` string; the
        5th whitespace-separated token of that string is taken as the address.
    :return: True when an address lies in one of ``vpn_nets`` or reverse-
        resolves to a host ending in ``vpn.dhcp.yndx.net``; False otherwise.

    Candidates that are empty, malformed (fewer than 5 tokens) or carry an
    address that is not an IP literal are skipped with a warning instead of
    aborting the whole task.
    """
    import ipaddr

    ips = []
    for cv in candidates.values():
        cand = cv.get("candidate")
        if not cand:
            continue
        parts = cand.split()
        if len(parts) < 5:
            logging.warning("Malformed ICE candidate %r", cand)
            continue
        ips.append(parts[4])

    for ip in ips:
        try:
            addr = ipaddr.IPAddress(ip)
        except ValueError:
            # The address token may be a hostname rather than an IP literal
            # (e.g. mDNS obfuscation); it cannot be matched against vpn_nets.
            logging.warning("Skipping non-IP candidate address %s", ip)
            continue
        if any(addr in net for net in vpn_nets):
            return True
        try:
            host, _ = socket.getnameinfo((ip, 0), 0)
        except socket.error as e:
            # socket.error is the base of gaierror/herror: reverse lookup is
            # best-effort, so any resolver failure just skips this address.
            logging.warning("Can't resolve ip {}, exception {}".format(ip, e))
            continue
        if host.endswith("vpn.dhcp.yndx.net"):
            return True

    return False


def call_duration(call):
    """Return the call length as the ``str`` of a timedelta.

    The length runs from ``call["time_accepted"]`` to ``call["time_ended"]``,
    both parsed with ``dateutil.parser.parse``.
    """
    accepted_at, ended_at = [parser.parse(call[field]) for field in ("time_accepted", "time_ended")]
    return str(ended_at - accepted_at)


# Sentinel for "no data" points (list_ratio emits it for non-positive
# denominators); create_plot drops points equal to it. A real float infinity
# is used so the sentinel survives scaling by list_multiply / list_divide
# (inf * 100 == inf). A finite sentinel would be scaled into an ordinary value
# and then plotted as a huge spike.
inf = float('inf')


def list_diff(l):
    """Return the differences between consecutive elements of ``l``.

    The result has ``len(l) - 1`` items; it is empty when ``l`` has fewer
    than two elements.
    """
    return [cur - prev for prev, cur in zip(l, l[1:])]


def list_multiply(l, v):
    """Return a new list with every element of ``l`` multiplied by ``v`` (as floats)."""
    scaled = []
    for item in l:
        scaled.append(1. * v * item)
    return scaled


def list_divide(l, v):
    """Return a new list with every element of ``l`` divided by ``v``.

    Each item is computed as ``(1. / v) * item``. An empty ``l`` returns
    ``[]`` without ever evaluating ``1. / v``.
    """
    divided = []
    for item in l:
        divided.append(1. / v * item)
    return divided


def list_ratio(f, s):
    """Return element-wise ``f[i] / s[i]`` as floats.

    Positions where the denominator is not positive get the module-level
    ``inf`` sentinel. If the lists differ in length, ``[]`` is returned.
    """
    if len(f) != len(s):
        return []
    return [1. * num / den if den > 0 else inf for num, den in zip(f, s)]


def list_part(f, s):
    """Return element-wise share ``f[i] / (f[i] + s[i])`` via ``list_ratio``.

    If the lists differ in length, ``[]`` is returned.
    """
    if len(f) != len(s):
        return []
    totals = [a + b for a, b in zip(f, s)]
    return list_ratio(f, totals)


class RtcLogSupport(sdk2.Task):
    """Sandbox task that turns the mediator logs of one call into resources.

    ``on_execute`` unpacks the zipped mediator log resource and extracts call
    metadata and per-participant keepalive stats from it. It then publishes a
    plots resource and a call-meta-info resource.
    """

    # Plot specs, in output order: (kind, stat keys, divisor, plot name suffix).
    # Stat keys are the '-'-joined paths produced by extract_stats. Kinds:
    #   'raw'       -- the stat itself against 'ts'
    #   'rate'      -- list_diff(stat) divided by divisor
    #   'part_pct'  -- 100 * d(a) / (d(a) + d(b))      (list_part)
    #   'ratio_pct' -- 100 * d(a) / d(b)               (list_ratio)
    #   'ratio_div' -- (d(a) / d(b)) divided by divisor
    # where d() is list_diff between consecutive keepalive rows.
    # NOTE(review): the divisors 10 and 10 * 1000 presumably assume a fixed
    # 10 s keepalive interval (not derived from 'ts'). The '*_in_bits_*' and
    # '*_in_ms' names do not obviously match these units -- confirm before
    # relying on the absolute values.
    _PLOT_SPECS = (
        # inbound_audio_stream
        ('raw', ('inbound_rtp_audio_stream-packets_received',), None,
         '_inbound_audio_stream_packets_received'),
        ('rate', ('inbound_rtp_audio_stream-packets_received',), 10,
         '_inbound_audio_stream_packets_received_per_sec'),
        ('raw', ('inbound_rtp_audio_stream-packets_lost',), None,
         '_inbound_audio_stream_packets_lost'),
        ('rate', ('inbound_rtp_audio_stream-packets_lost',), 10,
         '_inbound_audio_stream_packets_lost_per_sec'),
        ('part_pct', ('inbound_rtp_audio_stream-packets_lost',
                      'inbound_rtp_audio_stream-packets_received'), None,
         '_inbound_audio_stream_packets_lost_percent'),
        ('raw', ('inbound_rtp_audio_stream-fraction_lost',), None,
         '_inbound_audio_stream_fraction_lost'),
        ('raw', ('inbound_rtp_audio_stream-jitter',), None,
         '_inbound_audio_stream_jitter'),

        # inbound_audio_track
        ('raw', ('inbound_rtp_audio_stream-audio_track-audio_level',), None,
         '_inbound_audio_track_audio_level'),
        ('raw', ('inbound_rtp_audio_stream-audio_track-total_audio_energy',), None,
         '_inbound_audio_track_total_audio_energy'),
        ('raw', ('inbound_rtp_audio_stream-audio_track-total_samples_duration',), None,
         '_inbound_audio_track_total_samples_duration'),
        ('raw', ('inbound_rtp_audio_stream-audio_track-total_samples_received',), None,
         '_inbound_audio_track_total_samples_received'),
        ('raw', ('inbound_rtp_audio_stream-audio_track-concealed_samples',), None,
         '_inbound_audio_track_concealed_samples'),
        ('raw', ('inbound_rtp_audio_stream-audio_track-concealment_events',), None,
         '_inbound_audio_track_concealment_events'),
        ('ratio_pct', ('inbound_rtp_audio_stream-audio_track-concealed_samples',
                       'inbound_rtp_audio_stream-audio_track-total_samples_received'), None,
         '_inbound_audio_track_concealed_samples_percent'),

        # inbound_video_stream
        ('raw', ('inbound_rtp_video_stream-nack_count',), None,
         '_inbound_video_stream_nack_count'),
        ('raw', ('inbound_rtp_video_stream-packets_received',), None,
         '_inbound_video_stream_packets_received'),
        ('rate', ('inbound_rtp_video_stream-packets_received',), 10,
         '_inbound_video_stream_packets_received_per_sec'),
        ('raw', ('inbound_rtp_video_stream-bytes_received',), None,
         '_inbound_video_stream_bytes_received'),
        ('rate', ('inbound_rtp_video_stream-bytes_received',), 10 * 1000,
         '_inbound_video_stream_bytes_received_in_bits_per_sec'),
        ('raw', ('inbound_rtp_video_stream-frames_decoded',), None,
         '_inbound_video_stream_frames_decoded'),
        ('rate', ('inbound_rtp_video_stream-frames_decoded',), 10,
         '_inbound_video_stream_frames_decoded_per_sec'),
        ('part_pct', ('inbound_rtp_video_stream-packets_lost',
                      'inbound_rtp_video_stream-packets_received'), None,
         '_inbound_video_stream_packets_lost_percent'),

        # inbound_video_track
        ('raw', ('inbound_rtp_video_stream-video_track-frame_width',), None,
         '_inbound_video_track_frame_width'),
        ('raw', ('inbound_rtp_video_stream-video_track-frame_height',), None,
         '_inbound_video_track_frame_height'),
        ('raw', ('inbound_rtp_video_stream-video_track-frames_dropped',), None,
         '_inbound_video_track_frames_dropped'),
        ('raw', ('inbound_rtp_video_stream-video_track-frames_received',), None,
         '_inbound_video_track_frames_received'),
        ('ratio_pct', ('inbound_rtp_video_stream-video_track-frames_dropped',
                       'inbound_rtp_video_stream-video_track-frames_received'), None,
         '_inbound_video_track_frames_dropped_percent'),
        ('raw', ('inbound_rtp_video_stream-video_track-partial_frames_lost',), None,
         '_inbound_video_track_partial_frames_lost'),

        # outbound_audio_stream
        ('raw', ('outbound_rtp_audio_stream-packets_sent',), None,
         '_outbound_audio_stream_packets_sent'),
        ('rate', ('outbound_rtp_audio_stream-packets_sent',), 10,
         '_outbound_audio_stream_packets_sent_per_sec'),
        ('raw', ('outbound_rtp_audio_stream-bytes_sent',), None,
         '_outbound_audio_stream_bytes_sent'),
        ('rate', ('outbound_rtp_audio_stream-bytes_sent',), 10 * 1000,
         '_outbound_audio_stream_bytes_sent_in_bits_per_sec'),

        # outbound_audio_track
        ('raw', ('outbound_rtp_audio_stream-audio_track-audio_level',), None,
         '_outbound_audio_track_audio_level'),
        ('raw', ('outbound_rtp_audio_stream-audio_track-total_audio_energy',), None,
         '_outbound_audio_track_total_audio_energy'),
        ('raw', ('outbound_rtp_audio_stream-audio_track-total_samples_duration',), None,
         '_outbound_audio_track_total_samples_duration'),
        ('raw', ('outbound_rtp_audio_stream-audio_track-total_samples_received',), None,
         '_outbound_audio_track_total_samples_received'),
        ('raw', ('outbound_rtp_audio_stream-audio_track-concealed_samples',), None,
         '_outbound_audio_track_concealed_samples'),
        ('raw', ('outbound_rtp_audio_stream-audio_track-concealment_events',), None,
         '_outbound_audio_track_concealment_events'),
        ('ratio_pct', ('outbound_rtp_audio_stream-audio_track-concealed_samples',
                       'outbound_rtp_audio_stream-audio_track-total_samples_received'), None,
         '_outbound_audio_track_concealed_samples_percent'),

        # outbound_video_stream
        ('raw', ('outbound_rtp_video_stream-nack_count',), None,
         '_outbound_video_stream_nack_count'),
        ('raw', ('outbound_rtp_video_stream-packets_sent',), None,
         '_outbound_video_stream_packets_sent'),
        ('rate', ('outbound_rtp_video_stream-packets_sent',), 10,
         '_outbound_video_stream_packets_sent_per_sec'),
        ('raw', ('outbound_rtp_video_stream-bytes_sent',), None,
         '_outbound_video_stream_bytes_sent'),
        ('rate', ('outbound_rtp_video_stream-bytes_sent',), 10 * 1000,
         '_outbound_video_stream_bytes_sent_in_bits_per_sec'),
        ('raw', ('outbound_rtp_video_stream-frames_encoded',), None,
         '_outbound_video_stream_frames_encoded'),
        ('rate', ('outbound_rtp_video_stream-frames_encoded',), 10,
         '_outbound_video_stream_frames_encoded_per_sec'),

        # outbound_video_track
        ('raw', ('outbound_rtp_video_stream-video_track-frame_width',), None,
         '_outbound_video_track_frame_width'),
        ('raw', ('outbound_rtp_video_stream-video_track-frame_height',), None,
         '_outbound_video_track_frame_height'),
        ('raw', ('outbound_rtp_video_stream-video_track-frames_sent',), None,
         '_outbound_video_track_frames_sent'),
        ('raw', ('outbound_rtp_video_stream-video_track-partial_frames_lost',), None,
         '_outbound_video_track_partial_frames_lost'),

        # ice_candidate_pair
        ('raw', ('transport-bytes_sent',), None,
         '_ice_candidate_pair_bytes_sent'),
        ('rate', ('transport-bytes_sent',), 1000 * 10,
         '_ice_candidate_pair_bytes_sent_in_bits_rep_sec'),
        ('raw', ('transport-bytes_received',), None,
         '_ice_candidate_pair_bytes_received'),
        ('rate', ('transport-bytes_received',), 1000 * 10,
         '_ice_candidate_pair_bytes_received_in_bits_rep_sec'),
        ('raw', ('transport-selected_candidate_pair-available_incoming_bitrate',), None,
         '_ice_candidate_pair_available_incoming_bitrate'),
        ('raw', ('transport-selected_candidate_pair-available_outgoing_bitrate',), None,
         '_ice_candidate_pair_available_outgoing_bitrate'),
        ('raw', ('transport-selected_candidate_pair-total_round_trip_time',), None,
         '_ice_candidate_pair_total_round_trip_time'),
        ('ratio_div', ('transport-selected_candidate_pair-total_round_trip_time',
                       'transport-selected_candidate_pair-responses_received'), 10 * 1000,
         '_ice_candidate_pair_total_rtt_by_responses_received_in_ms'),
        ('raw', ('transport-selected_candidate_pair-current_round_trip_time',), None,
         '_ice_candidate_pair_current_round_trip_time'),
        ('raw', ('transport-selected_candidate_pair-responses_received',), None,
         '_ice_candidate_pair_responses_received'),
    )

    def get_all_keys(self, dct, path):
        """Return every key path from ``dct`` down to its non-dict leaves.

        :param dct: nested dict to walk.
        :param path: key prefix prepended to every returned path (not mutated).
        :return: list of key lists, one per leaf value.
        """
        res = []
        for key, value in dct.items():
            new_path = path + [key]
            if isinstance(value, dict):
                res.extend(self.get_all_keys(value, new_path))
            else:
                res.append(new_path)
        return res

    def get_value_by_path(self, stat, path):
        """Follow ``path`` through nested dicts in ``stat``.

        :return: the value found, or the string ``"bad"`` when a key is
            missing or an intermediate node is not a dict.
        """
        node = stat
        for key in path:
            # Guard against non-dict nodes: ``key in <str>`` would do a
            # substring test and ``key in <number>`` would raise TypeError.
            if not isinstance(node, dict) or key not in node:
                return "bad"
            node = node[key]
        return node

    @staticmethod
    def _is_numeric(value):
        """Return True if ``float(str(value))`` succeeds."""
        try:
            float(str(value))
        except ValueError:
            return False
        return True

    def get_values_by_path(self, stats, path):
        """Collect the value at ``path`` from every stat, keeping numeric ones only.

        Non-numeric values (including the ``"bad"`` marker for missing paths)
        are dropped. The result can therefore be shorter than ``stats``.
        """
        values = [self.get_value_by_path(s, path) for s in stats]
        return [v for v in values if self._is_numeric(v)]

    def extract_stats(self, stats):
        """Flatten a list of nested stats dicts into per-metric series.

        Metric keys are the '-'-joined leaf paths of the LAST stat. Each series
        holds only numeric values (see get_values_by_path). ``'ts'`` holds every
        stat's ``'ts'`` value.

        :return: ``defaultdict(list)``, so absent metrics read as ``[]``.
        """
        res = defaultdict(list)
        if not stats:
            return res

        for key in self.get_all_keys(stats[-1], []):
            res["-".join(key)] = self.get_values_by_path(stats, key)
        # 'ts' values are datetimes, which the numeric filter above discards,
        # so the series is set explicitly once all metrics are extracted.
        res['ts'] = [s['ts'] for s in stats]

        logging.debug("All stats {}".format(res))

        return res

    def parse_mediator_logs(self, mediator_logs):
        """Extract call metadata and per-participant keepalive stats.

        :param mediator_logs: log lines, one JSON object per line; blank lines
            are skipped.
        :return: ``(call, caller_stats, callee_stats)``. ``call`` is the
            metadata dict and both stats are extract_stats results. When there
            are no log rows, the call is reported as coming from testing and
            both stats are empty.
        """
        call = {
            'caller_guid': 'unknown',
            'caller_login': 'unknown',
            'caller_platform': 'unknown',
            'caller_device_info': 'unknown',
            'callee_guid': 'unknown',
            'callee_login': 'unknown',
            'callee_platform': 'unknown',
            'callee_device_info': 'unknown',
            'iso_eventtime': 'unknown',
            'time': 'unknown',
            'description': [],
            'summary': '',
            'platforms': [],
            'environment': 'unknown',
            'caller_vpn_enabled': False,
            'callee_vpn_enabled': False,
            'call_duration': 'unknown',
            'call_created': 'unknown',
            'logins_found': False,
            'chat_id': 'unknown',
            'is_caller_relay': False,
            'caller_ip': 'unknown',
            'is_callee_relay': False,
            'callee_ip': 'unknown',
        }

        rows = [json.loads(line) for line in mediator_logs if line.strip()]
        if not rows:
            call['environment'] = 'testing'
            self.set_info("Call from testing")
            # Always return a 3-tuple: on_execute unpacks the result.
            return call, self.extract_stats([]), self.extract_stats([])

        call['time'] = rows[0]['ts']
        keepalive_rows = []

        for row in rows:
            msg = row.get('msg')

            if msg == 'db: Created call':
                logging.info("Find Created call row")

                call['caller_guid'] = row['call']['caller']['transport_id']
                call['callee_guid'] = row['call']['callee']['transport_id']
                call['iso_eventtime'] = row['iso_eventtime']
                call['chat_id'] = row['call']['chat_id']

            elif msg == 'switch: MakeCall request':
                device_info = row.get('device_info')
                if device_info is not None:
                    logging.info("Find MakeCall request row")

                    call['caller_platform'] = device_info['platform']
                    call['platforms'].append(call['caller_platform'])
                    call['caller_device_info'] = device_info

            elif msg == 'switch: AcceptCall request':
                device_info = row.get('device_info')
                if device_info is not None:
                    logging.info("Find AcceptCall request row")

                    call['callee_platform'] = device_info['platform']
                    call['platforms'].append(call['callee_platform'])
                    call['callee_device_info'] = device_info

            elif msg == 'master: Delete finished call':
                logging.info("Find Delete finished call row")

                finished = row['call']
                if 'SelfID' in finished['callee']['alias']:
                    call['callee_guid'] = finished['callee']['alias']['SelfID']

                call['caller_vpn_enabled'] = vpn_enabled(finished['caller']['candidates'])
                call['callee_vpn_enabled'] = vpn_enabled(finished['callee']['candidates'])
                call['call_duration'] = call_duration(finished)
                call['call_created'] = finished['time_created']

            elif msg == "Received keepalive stats":
                keepalive_rows.append(row)

        # Attribute keepalive stats only after the whole log has been read, so
        # the caller guid from 'db: Created call' is known even when some stats
        # rows were logged before that row.
        caller_stats = []
        callee_stats = []
        for row in keepalive_rows:
            row_stats = row['stats']['standard']
            row_stats['ts'] = parser.parse(row['ts'])
            if row['sender'] == call['caller_guid']:
                caller_stats.append(row_stats)
            else:
                callee_stats.append(row_stats)

        self._resolve_logins(call)

        self.set_info("Call from {}".format(call['environment']))
        return call, self.extract_stats(caller_stats), self.extract_stats(callee_stats)

    def _resolve_logins(self, call):
        """Fill caller/callee logins, ``logins_found`` and ``environment`` in ``call``.

        Staff nicknames are tried first. If neither participant is found
        there, the call is marked external and Yandex nicknames are looked up
        through ``util.MetaApiHelper`` instead.
        """
        guids = [call['caller_guid'], call['callee_guid']]

        logging.info('Get staff nicknames')

        logins = util.get_staff_nicknames(guids, util.get_messenger_auth())
        if logins[call['caller_guid']] == 'unknown' and logins[call['callee_guid']] == 'unknown':
            call['environment'] = 'mssngr_external'

            logging.info('Get yandex nicknames')

            api_helper = util.MetaApiHelper(self.Parameters.env)
            logins = api_helper.get_yandex_nicknames(guids)
        else:
            call['environment'] = 'mssngr_internal'

        call['caller_login'] = logins[call['caller_guid']]
        call['callee_login'] = logins[call['callee_guid']]
        call['logins_found'] = (call['caller_login'] != 'unknown'
                                and call['callee_login'] != 'unknown')

    def create_call_meta_info_resource(self, call):
        """Write ``call`` as JSON into a new RtcCallMetaInfoResource and return its id."""
        resource = util.RtcCallMetaInfoResource(self, 'call_info', 'call_info.txt')
        resource_data = sdk2.ResourceData(resource)
        resource_data.path.write_bytes(json.dumps(call))
        resource_data.ready()
        return resource.id

    def create_plot(self, time, value, plot_name):
        """Render ``value`` against ``time`` as a PNG line plot.

        Points whose value equals the module-level ``inf`` sentinel are skipped.

        :return: ``{'name': <plot_name>.png, 'data': <png bytes>}`` from
            save_plot, or ``{}`` when ``value`` is empty or its length differs
            from ``time``.
        """
        import matplotlib.pyplot as plt

        if len(value) == 0 or len(time) != len(value):
            logging.error("can not create plot '{}', x: {}, y: {}".format(plot_name, time, value))
            return {}

        points = [(t, v) for t, v in zip(time, value) if v != inf]
        x = [t for t, _ in points]
        y = [v for _, v in points]

        plt.plot(x, y)

        plt.title(plot_name)
        plt.xlabel('time')
        return self.save_plot(plt, plot_name + ".png")

    def save_plot(self, plt, name):
        """Format the current figure's time axis, render it to PNG and close it.

        :return: ``{'name': name, 'data': <png bytes>}``.
        """
        plt.xticks(rotation=15)
        ax = plt.gca()
        import matplotlib.dates as md
        xfmt = md.DateFormatter('%H:%M:%S')
        ax.xaxis.set_major_formatter(xfmt)
        plt.grid(True)
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        data = buf.getvalue()
        buf.close()
        plt.close()
        return {'name': name, 'data': data}

    def _plot_series(self, stats, kind, keys, divisor):
        """Build the ``(time, value)`` series for one _PLOT_SPECS entry.

        :raises ValueError: on an unknown ``kind``.
        """
        if kind == 'raw':
            return stats['ts'], stats[keys[0]]

        # Derived series come from deltas between consecutive rows, so they
        # are aligned with the timestamps from the second row on.
        time = stats['ts'][1:]
        deltas = [list_diff(stats[key]) for key in keys]
        if kind == 'rate':
            value = list_divide(deltas[0], divisor)
        elif kind == 'part_pct':
            value = list_multiply(list_part(deltas[0], deltas[1]), 100)
        elif kind == 'ratio_pct':
            value = list_multiply(list_ratio(deltas[0], deltas[1]), 100)
        elif kind == 'ratio_div':
            value = list_divide(list_ratio(deltas[0], deltas[1]), divisor)
        else:
            raise ValueError("unknown plot kind {}".format(kind))
        return time, value

    def create_call_plots_resource(self, caller_stats, callee_stats):
        """Render every _PLOT_SPECS plot for caller and callee into an RtcPlotsResource.

        Plots are named ``'caller'``/``'callee'`` + spec suffix. Plots that
        could not be built (create_plot returned ``{}``) are left out.
        """
        import matplotlib
        matplotlib.use('agg')

        plots = []
        for stats, name in ((caller_stats, 'caller'), (callee_stats, 'callee')):
            for kind, keys, divisor, suffix in self._PLOT_SPECS:
                time, value = self._plot_series(stats, kind, keys, divisor)
                plot = self.create_plot(time, value, name + suffix)
                if plot:
                    plots.append(plot)

        util.create_resource(self, 'plots', plots, util.RtcPlotsResource)

    class Requirements(sdk2.Task.Requirements):
        environments = [
            PipEnvironment("ipaddr"),
            PipEnvironment("matplotlib", "1.5.1", use_wheel=True),
        ]

    class Parameters(sdk2.Task.Parameters):
        with sdk2.parameters.String("Environment") as env:
            env.values.prod = env.Value('Production', default=True)
            env.values.alpha = 'Alpha'

        mediator_log_resource = sdk2.parameters.Resource(
            'Mediator logs zip',
            resource_type=util.RtcLogMediatorResource,
            required=True
        )

        # container = sdk2.parameters.Resource(
        #     'LXC container for task',
        #     resource_type=LXC_CONTAINER,
        #     required=False,
        #     default_value=111111,
        # )

    def on_execute(self):
        """Parse the mediator log resource and publish plots and call meta info.

        The zip from ``mediator_log_resource`` is extracted into the task
        directory and its ``mediator_logs.txt`` is parsed.
        """
        import ipaddr

        # Rebuild the list in place instead of appending, so repeated
        # executions in one process do not accumulate duplicate networks.
        vpn_nets[:] = [ipaddr.IPNetwork(net) for net in vpn_prefixes]

        logs_path = str(sdk2.ResourceData(self.Parameters.mediator_log_resource).path)
        with ZipFile(logs_path, "r") as zf:
            zf.extractall(str(self.path()))
        log_file_path = str(self.path('mediator_logs.txt'))

        with open(log_file_path) as f:
            mediator_log = f.readlines()

        call, caller_stats, callee_stats = self.parse_mediator_logs(mediator_log)
        self.create_call_plots_resource(caller_stats, callee_stats)
        self.create_call_meta_info_resource(call)
