# -*- coding: utf-8 -*-

import subprocess
import datetime
import logging
import os
import pprint
import re
import socket
import sys
import time
from collections import defaultdict

import dateutil.parser
import luigi
import splunklib.client
import splunklib.results as results
import yt.wrapper as yt

from lib.luigi import base_luigi_task
from lib.luigi import yt_luigi
from rtcconf import config
from utils import mr_utils as mr
from utils import utils
from utils.yql_utils import run_yql

from data_imports.import_logs.app_metrica_day import ImportAppMetrikaDayTask
from data_imports.import_logs.watch_log.graph_watch_log import ImportWatchLogDayTask

# getLogger() is called without a name, so this is the root logger; every
# message emitted by this module goes through it.
logger = logging.getLogger()


def find_raw_field(field, raw, default=None):
    """Extract the value of a ``field="value"`` pair from a raw log line.

    ``field`` is interpolated into the regular expression unchanged and the
    first non-greedy match found anywhere in ``raw`` is used, so a field name
    that is a suffix of another one (e.g. ``type`` / ``ip_type``) may hit the
    longer name first.

    :param field: name of the field to look for
    :param raw: raw log line
    :param default: value returned when the field is not present
    :return: the text between the double quotes, or ``default``
    """
    pattern = r'%s="(.*?)"' % field
    found = re.search(pattern, raw)
    return found.group(1) if found else default


def get_ip(ip_raw):
    """Label an address string with its IP family.

    :param ip_raw: address string as found in the log (may be empty or None)
    :return: tuple ``(ip_raw, ip_type)`` where ``ip_type`` is ``'IPv6'`` when
        the address contains a colon, ``'IPv4'`` for any other non-empty
        value and the string ``'None'`` for an empty/missing address
    """
    if not ip_raw:
        return ip_raw, 'None'
    return ip_raw, ('IPv6' if ':' in ip_raw else 'IPv4')


def splunk_time_to_unixtime(datetime_str):
    """Convert a Splunk time string to a unix timestamp rendered as a string.

    The string is parsed with ``dateutil``; the ``MSK`` abbreviation is mapped
    to a +3h (10800 s) offset.

    NOTE(review): ``time.mktime`` interprets the broken-down time as the local
    time of the host and ``dt.timetuple()`` does not apply the parsed UTC
    offset, so the result presumably relies on the host running in the same
    zone as the log -- confirm.

    :param datetime_str: date/time string to parse
    :return: unix timestamp as ``str``, or ``None`` when the value cannot be
        parsed or converted
    """
    tzinfos = {"MSK": 10800}
    try:
        dt = dateutil.parser.parse(datetime_str, tzinfos=tzinfos)
        return str(int(time.mktime(dt.timetuple())))
    except Exception:
        # Best-effort conversion: the caller treats None as "no timestamp".
        # Unlike a bare ``except:`` this lets KeyboardInterrupt and SystemExit
        # propagate.
        return None


def read_password(pass_path):
    """Read the Splunk password from the first line of a file.

    :param pass_path: path to the password file
    :return: first line of the file with surrounding whitespace stripped
    :raises Exception: when ``pass_path`` is empty/None or the file has no
        content at all
    """
    if not pass_path:
        raise Exception("RADIUS_SPLUNK_PASS_PATH is not set")

    with open(pass_path) as pass_file:
        first_line = pass_file.readline()

    if not first_line:
        raise Exception("%s is empty" % pass_path)
    return first_line.strip()


class SplunkLookup(object):
    """Wrapper around a Splunk connection used to export radius log entries."""

    def __init__(self, host, port, username, password):
        """Try to connect to Splunk.

        A ``socket.error`` during connection is logged and swallowed; in that
        case ``self.splunk`` stays ``None`` and ``search_users`` yields nothing.
        """
        self.splunk = None
        try:
            logger.info('Trying to establish a connection to %s:%s ...', host, port)
            self.splunk = splunklib.client.connect(host=host, port=port, username=username, password=password)
        except socket.error as err:
            logger.error('Splunk server is not available. %s', err)

    def search_users(self, date):
        """Generate radius entries found in Splunk for one day.

        :param date: day to export, formatted as ``YYYY-MM-DD``
        :return: generator of dicts with the keys ``ip``, ``timestamp``,
            ``login``, ``conn_type``, ``event``, ``ip_type`` and ``rec_type``
            (always ``'radius'``); only entries that have an ip, a login and
            a timestamp are yielded

        Summary statistics are logged once the result stream is exhausted, so
        they only appear when the generator is consumed to the end.
        """
        if not self.splunk:
            return

        # search for a full single day
        time_range = to_splunk_time_range_predicate(date)
        query = 'search {} sourcetype="firewall_logs" login!=""'.format(time_range)
        logger.info('Splunk search query: %s', query)

        # search method can only retrieve 50k records, thus need to use export
        job = self.splunk.jobs.export(query, search_mode='normal')

        processed = 0
        succeed = 0
        fail_reasons = defaultdict(int)
        ip_types = defaultdict(int)
        by_type = defaultdict(lambda: defaultdict(int))
        start_time = datetime.datetime.now()

        for idx, rec in enumerate(results.ResultsReader(job)):
            if idx % 1000 == 0:
                # idx items have been fully handled before the current one
                seconds_from_start = (datetime.datetime.now() - start_time).total_seconds()
                logger.info("Processed %d radius entries in %s",
                            idx, str(datetime.timedelta(seconds=seconds_from_start)))

            # idx is zero-based, so the number of items seen so far is idx + 1
            processed = idx + 1

            # the reader also emits non-dict items; only dict results carry data
            if not isinstance(rec, dict):
                continue

            raw = rec['_raw']
            if not ('assigned_ip' in raw and 'login' in raw):
                fail_reasons['bad format'] += 1
                continue

            ip_raw = find_raw_field('assigned_ip', raw)
            ip, ip_type = get_ip(ip_raw)
            ip_types[ip_type] += 1

            login = find_raw_field('login', raw)
            if not login or login == 'UNDEF':
                fail_reasons['undefined login'] += 1
                login = None

            # prefer the timestamp embedded in the raw line, fall back to
            # Splunk's own _time field
            ts = find_raw_field('timestamp', raw)
            if not ts and rec.get("_time"):
                ts = splunk_time_to_unixtime(rec['_time'])
            if not ts:
                fail_reasons['no ts'] += 1

            event = find_raw_field('event', raw, '')
            conn_type = find_raw_field('type', raw, '')

            by_type[conn_type][event] += 1

            if ip and login and ts:
                succeed += 1
                yield {'ip': ip, 'timestamp': ts, 'login': login,
                       'conn_type': conn_type, 'event': event,
                       'ip_type': ip_type, 'rec_type': 'radius'}

        logger.info("==== STATS ====")
        logger.info("Processed: %d", processed)
        logger.info("Processed by type:\n%s", pprint.pformat(utils.default_to_regular(by_type)))
        logger.info("Processed by ip type:\n%s", pprint.pformat(utils.default_to_regular(ip_types)))
        logger.info("Succeed: %d", succeed)
        logger.info("Failed:\n%s", pprint.pformat(utils.default_to_regular(fail_reasons)))


def upload_log_to_yt(local_file, yt_table):
    """Upload a local DSV log file to a YT table and make timestamps numeric.

    :param local_file: path to the local file in ``key=value`` tab-separated
        (dsv) format
    :param yt_table: destination YT table path; it is written first and then
        rewritten in place by the ``convert_timestamp`` mapper
    """
    # The context manager closes the handle even if the upload fails; the
    # original left the file open.
    with open(local_file) as f:
        yt.write_table(yt_table, f, format="dsv", raw=True)
    yt.run_map(convert_timestamp, yt_table, yt_table)


def convert_timestamp(rec):
    """Mapper: replace the record's ``timestamp`` value with its integer form.

    The record is modified in place and yielded back unchanged otherwise.
    """
    numeric_ts = long(rec['timestamp'])
    rec['timestamp'] = numeric_ts
    yield rec


def distinct_ip(key, recs):
    """Reducer: emit a single ``{'ip': ...}`` row per reduce key.

    :param key: reduce key; only its ``'ip'`` entry is used
    :param recs: records of the group (ignored)
    """
    unique_ip = key['ip']
    yield {'ip': unique_ip}


def to_splunk_time_range_predicate(date_str):
    """Build the Splunk ``earliest=... latest=...`` predicate for one day.

    :param date_str: day formatted as ``YYYY-MM-DD``
    :return: predicate spanning 00:00:00 to 23:59:59 of that day, with the
        date written as ``month/day/year`` without zero padding
    :raises ValueError: when ``date_str`` does not match ``%Y-%m-%d``
    """
    parsed = datetime.datetime.strptime(date_str, "%Y-%m-%d")
    day = "{0.month}/{0.day}/{0.year}".format(parsed)
    return "earliest={0}:00:00:00 latest={0}:23:59:59".format(day)


def calculate_ips(yt_folder, date, store_days):
    """Build the distinct-IP tables for ``date`` on YT.

    Two tables are produced under the folder for ``date``: ``ips`` (distinct
    IPs of that day's ``radius_log``) and ``all_radius_ips`` (distinct IPs
    over the ``ips`` tables returned by ``mr.get_existing_date_tables``).

    :param yt_folder: root YT folder of the radius log tables
    :param date: day being processed
    :param store_days: passed to ``mr.get_existing_date_tables`` --
        presumably the number of past days to merge; verify against mr_utils
    """
    radius_log_table = mr.get_date_table(yt_folder, date, 'radius_log')
    ip_table = mr.get_date_table(yt_folder, date, 'ips')

    # No mapper (None): the log is grouped by 'ip' and reduced to one row per IP.
    yt.run_map_reduce(None, distinct_ip, radius_log_table, ip_table, reduce_by='ip')
    # NOTE(review): distinct_ip emits only the 'ip' column, so 'timestamp' is
    # absent from this table -- confirm the second sort key is intended.
    yt.run_sort(ip_table, ip_table, sort_by=['ip', 'timestamp'])

    # Merge the per-day 'ips' tables into one distinct list. This lookup runs
    # after today's 'ips' table has been written above.
    ips_tables_range = mr.get_existing_date_tables(yt_folder, 'ips', store_days)
    all_ips_table = mr.get_date_table(yt_folder, date, 'all_radius_ips')

    yt.run_reduce(distinct_ip, ips_tables_range, all_ips_table, reduce_by='ip')
    # No destination is given here, unlike the sort of ip_table above.
    yt.run_sort(all_ips_table, sort_by=['ip', 'timestamp'])


def get_radius_ips(date):
    """Load the set of radius IPs stored for ``date``.

    :param date: day whose ``all_radius_ips`` table should be read
    :return: set of the ``ip`` values of that table, or an empty set when the
        table does not exist
    """
    ips_table = mr.get_date_table(config.RADIUS_LOG_YT_FOLDER, date, 'all_radius_ips')
    if yt.exists(ips_table):
        return {row['ip'] for row in yt.read_table(ips_table, raw=False)}
    return set()


def has_radius_ips_table(date):
    """Return whether the ``all_radius_ips`` table for ``date`` exists on YT."""
    return yt.exists(mr.get_date_table(config.RADIUS_LOG_YT_FOLDER, date, 'all_radius_ips'))


def main(date, local_folder, yt_folder):
    """Export one day of radius entries from Splunk and publish them to YT.

    The entries are first written to the local file ``local_folder + date``
    as tab-separated ``key=value`` lines, then uploaded to the ``radius_log``
    table and post-processed by ``calculate_ips``.

    :param date: day to export (``YYYY-MM-DD``)
    :param local_folder: prefix of the local dump path; the date is appended
        to it directly
    :param yt_folder: prefix of the YT output folder; the date is appended to
        it directly
    :raises Exception: when the export produced no usable entries
    """
    lookup = SplunkLookup(
        config.RADIUS_SPLUNK_URL,
        str(config.RADIUS_SPLUNK_PORT),
        config.RADIUS_SPLUNK_USR,
        read_password(config.RADIUS_SPLUNK_PASS_PATH),
    )

    local_dir = os.path.dirname(local_folder)
    if not os.path.exists(local_dir):
        os.makedirs(local_dir)

    local_path = local_folder + date
    seen_ips = set()
    seen_logins = set()
    min_ts = 0
    max_ts = 0

    with open(local_path, 'w') as dump:
        for entry in lookup.search_users(date):
            seen_ips.add(entry['ip'])
            seen_logins.add(entry['login'])

            ts = long(entry['timestamp'])
            # 0 doubles as the "not set yet" marker for the minimum
            if min_ts == 0 or ts < min_ts:
                min_ts = ts
            max_ts = max(max_ts, ts)

            pairs = [name + '=' + value for name, value in entry.iteritems()]
            dump.write('\t'.join(pairs) + '\n')

    logger.info("The number of processed unique IPs is %d", len(seen_ips))
    logger.info("The number of processed unique logins is %d", len(seen_logins))

    if not seen_ips:
        raise Exception('Radius data is not available')

    first_seen = utils.ts_to_datetime_str(min_ts)
    last_seen = utils.ts_to_datetime_str(max_ts)
    logger.info("Processed entries from %s to %s", first_seen, last_seen)

    logger.info("Uploading result to YT...")
    mr.mkdir(yt_folder + date)

    upload_log_to_yt(local_path, mr.get_date_table(yt_folder, date, 'radius_log'))
    calculate_ips(yt_folder, date, int(config.STORE_DAYS))


class ImportRadiusSplunkLog(base_luigi_task.BaseTask):
    """Luigi task that runs the Splunk radius export (``main``) for one day."""

    date = luigi.Parameter()
    priority = 1

    def run(self):
        """Run the export with the local and YT folders taken from config."""
        main(self.date, config.RADIUS_LOG_LOCAL_FOLDER, config.RADIUS_LOG_YT_FOLDER)

    def output(self):
        """Return YT targets for the day's ``radius_log`` and ``all_radius_ips`` tables."""
        day_folder = config.RADIUS_LOG_YT_FOLDER + self.date + '/'
        table_names = ('radius_log', 'all_radius_ips')
        return [yt_luigi.YtTarget(day_folder + table_name) for table_name in table_names]


class FilterByRadius(yt_luigi.BaseYtTask):
    """Luigi task that runs the ``RadiusFilter`` YQL query for one day."""

    date = luigi.Parameter()

    def requires(self):
        """Depend on the radius export and on the day's AppMetrika and watch-log imports."""
        day = self.date
        return [
            ImportRadiusSplunkLog(day),
            ImportAppMetrikaDayTask(date=day, run_date=day),
            ImportWatchLogDayTask(date=day, run_date=day),
        ]

    def run(self):
        """Run the ``RadiusFilter`` query with the date and folder settings from config."""
        query_params = {'date': self.date}
        folder_settings = {
            'GRAPH_YT_OUTPUT_FOLDER': config.YT_OUTPUT_FOLDER,
            'BSWATCH_LOG_DIR': config.LOG_FOLDERS['bs_watch'],
            'RADIUS_LOG_YT_FOLDER': config.RADIUS_LOG_YT_FOLDER,
        }
        run_yql('RadiusFilter', query_params, folder_settings)

    def output_folders(self):
        """Map the ``yt_output`` key to the day's output folder."""
        day_output = config.YT_OUTPUT_FOLDER + self.date
        return {'yt_output': day_output}

    def output(self):
        """Return YT targets for the two radius-filtered tables."""
        table_paths = (
            self.out_f('yt_output') + '/raw_links/watch_log_filtered_by_radius',
            self.out_f('yt_output') + '/mobile/mmetrika_log_filtered_by_radius',
        )
        return [yt_luigi.YtTarget(table_path) for table_path in table_paths]


# Script entry point: run the radius export for the date given as the first
# command-line argument.
if __name__ == '__main__':
    # Point the YT client at the cluster configured in MR_SERVER.
    yt.config.set_proxy(config.MR_SERVER)

    # No argument validation: a missing argument raises IndexError here.
    date = sys.argv[1]
    logger.info('Running splunk Radius export at date %s', date)

    # Single worker, talking to a Luigi scheduler on port 8083.
    luigi.build([ImportRadiusSplunkLog(date)], workers=1, scheduler_port=8083)
