from __future__ import print_function
import datetime
import io
import logging
import os
import pprint
import re
import socket
import time
from collections import defaultdict, Counter

import dateutil.parser
import splunklib.client
import splunklib.results as results

import tvmauth

# Shared logger for every function and class in this module. getLogger() is
# called without a name, so this is the root logger.
logger = logging.getLogger()


def get_ip(ip_raw):
    """Label a raw IP string with its address family.

    The check is purely textual: a non-empty value containing a colon is
    labelled "IPv6", any other non-empty value "IPv4". The address itself
    is not validated.

    :param ip_raw: address text, or a falsy value (None / "") when absent.
    :return: tuple ``(ip_raw, ip_type)`` where ``ip_raw`` is passed through
        unchanged and ``ip_type`` is one of "IPv4", "IPv6" or "None".
    """
    # Guard clause: nothing to classify.
    if not ip_raw:
        return ip_raw, "None"

    family = "IPv6" if ":" in ip_raw else "IPv4"
    return ip_raw, family


def default_to_regular(d):
    """Convert defaultdict/Counter trees into plain, printable containers.

    * ``defaultdict`` and ``Counter`` instances become regular dicts whose
      keys are stringified and whose values are converted recursively.
    * ``set`` instances become lists (element order is not defined).
    * Anything else, including plain dicts, is returned untouched.

    :param d: the value to convert.
    :return: the converted value as described above.
    """
    if isinstance(d, (defaultdict, Counter)):
        # items() exists on both Python 2 and Python 3; iteritems() is
        # Python 2 only and raises AttributeError on Python 3.
        return {str(key): default_to_regular(value) for key, value in d.items()}
    if isinstance(d, set):
        return list(d)
    return d


def ts_to_datetime_str(ts):
    """Render a unix timestamp as a "YYYY-MM-DD HH:MM:SS" string.

    The timestamp is converted with ``datetime.fromtimestamp`` without a
    timezone argument, i.e. into the local time of the running process.

    :param ts: seconds since the epoch.
    :return: the formatted date-time string.
    """
    moment = datetime.datetime.fromtimestamp(ts)
    # format() on a datetime delegates to strftime() with the given spec.
    return format(moment, "%Y-%m-%d %H:%M:%S")


def splunk_time_to_unixtime(datetime_str):
    """Convert a Splunk time string into unix time (as a decimal string).

    The string is parsed with ``dateutil``; the "MSK" abbreviation is mapped
    to UTC+3 (10800 seconds). If the parsed value carries a UTC offset, that
    offset is honoured, so the result does not depend on the timezone of the
    host. A value without any offset is interpreted as local time.

    :param datetime_str: textual date-time to parse.
    :return: whole seconds since the epoch as ``str``, or ``None`` when the
        value cannot be parsed or converted (best-effort, never raises).
    """
    # Local import so the block does not depend on a new file-level import.
    import calendar

    try:
        tzinfos = {"MSK": 10800}
        dt = dateutil.parser.parse(datetime_str, tzinfos=tzinfos)
        if dt.utcoffset() is not None:
            # Aware datetime: mktime() would ignore the offset and treat the
            # wall-clock fields as host-local time, so convert through UTC.
            return str(calendar.timegm(dt.utctimetuple()))
        return str(int(time.mktime(dt.timetuple())))
    except Exception:
        # Deliberately best-effort, but no longer a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        return None


def find_raw_field(field, raw, default=None):
    """Extract the value of the first ``field="value"`` pair found in ``raw``.

    The field name is matched literally (regex metacharacters in it are
    escaped). The match is not anchored on the left, so ``field`` also
    matches when it is the tail of a longer key (e.g. "ip" inside
    ``assigned_ip="..."``); the first occurrence in the text wins.

    :param field: key name to look for.
    :param raw: text to search.
    :param default: value returned when no such pair is present.
    :return: the text between the double quotes (may be empty), or
        ``default`` when the field is not found.
    """
    # re.escape keeps characters such as "." or "(" in the field name from
    # being interpreted as pattern syntax.
    match = re.search(r'%s="(.*?)"' % re.escape(field), raw)
    if match is None:
        return default
    return match.group(1)


def to_splunk_time_range_predicate(date_str):
    """Build an ``earliest=... latest=...`` predicate covering one full day.

    Both bounds are formatted as ``M/D/YYYY:HH:MM:SS`` (month and day are
    not zero-padded), from 00:00:00 to 23:59:59 of the given date.

    :param date_str: the day in ``YYYY-MM-DD`` form.
    :return: the predicate string.
    :raises ValueError: if ``date_str`` does not match ``YYYY-MM-DD``.
    """
    day = datetime.datetime.strptime(date_str, "%Y-%m-%d")
    # Both bounds share the same calendar part and differ only in the time.
    day_part = "{}/{}/{}".format(day.month, day.day, day.year)
    return "earliest={} latest={}".format(day_part + ":00:00:00", day_part + ":23:59:59")


class ServiceWithTvm(splunklib.client.Service):
    """Splunk ``Service`` that attaches a TVM service ticket header.

    The ticket is added in two places: it is appended to the headers
    produced by the parent ``_auth_headers`` property, and it is supplied
    to POST requests that are issued without any headers (``self.http.post``
    is replaced by ``hacked_post`` for that purpose).
    """

    def __init__(self, tvm_client, **kwargs):
        """Create the service and wrap its HTTP ``post``.

        :param tvm_client: object exposing ``get_service_ticket_for(alias)``.
        :param kwargs: forwarded unchanged to the parent constructor.
        """
        super(ServiceWithTvm, self).__init__(**kwargs)
        self.tvm_client = tvm_client
        # Remember the original post before replacing it, so the wrapper
        # has something to delegate to.
        self.saved_post = self.http.post
        self.http.post = self.hacked_post

    def get_ya_service_ticket(self):
        """Return the ticket for the "radius" destination as a header pair."""
        header_value = str(self.tvm_client.get_service_ticket_for("radius"))
        return ("X-Ya-Service-Ticket", header_value)

    @property
    def _auth_headers(self):
        """Parent auth headers with the service ticket header appended."""
        auth = super(ServiceWithTvm, self)._auth_headers
        auth.append(self.get_ya_service_ticket())
        return auth

    def hacked_post(self, url, headers=None, **kwargs):
        """Delegate to the saved ``post``; header-less calls get the ticket.

        Non-empty ``headers`` are passed through untouched.
        """
        effective_headers = headers if headers else [self.get_ya_service_ticket()]
        return self.saved_post(url, headers=effective_headers, **kwargs)


class ResponseReaderWrapper(io.RawIOBase):
    """Expose a response reader through the ``io`` raw-stream interface.

    This allows the reader to be wrapped in ``io.BufferedReader``.
    Background:
    https://answers.splunk.com/answers/114045/python-sdk-results-resultsreader-extremely-slow.html
    """

    def __init__(self, responseReader):
        """
        :param responseReader: object providing ``read(n)`` and ``close()``.
        """
        self.responseReader = responseReader

    def readable(self):
        """The stream supports reading."""
        return True

    def close(self):
        """Close the wrapped reader."""
        self.responseReader.close()

    def read(self, n):
        """Read up to ``n`` bytes straight from the wrapped reader."""
        return self.responseReader.read(n)

    def readinto(self, b):
        """Fill the writable buffer ``b`` and return the number of bytes stored.

        At most ``len(b)`` bytes are requested from the wrapped reader; a
        return value of 0 signals end of stream.
        """
        data = self.responseReader.read(len(b))
        count = len(data)
        # One slice assignment copies the whole chunk at once instead of
        # assigning byte by byte in a Python-level loop.
        b[:count] = data
        return count


class SplunkLookup(object):
    """Logs in to Splunk (with a TVM ticket) and streams one day of records."""

    def __init__(self, host, port, username, password, tvm_config):
        """Try to connect and log in; failures to reach the server are logged.

        :param host: Splunk host.
        :param port: Splunk port.
        :param username: Splunk user name.
        :param password: Splunk password.
        :param tvm_config: object with ``client_id``, ``secret`` and
            ``server_id`` attributes used to build the TVM client.

        On ``socket.error`` the error is logged and ``self.splunk`` stays
        ``None``; other exceptions propagate to the caller.
        """
        # Stays None unless login succeeds; search_users() checks for this.
        self.splunk = None
        try:
            logger.info("Trying to establish a connection to %s:%s ...", host, port)

            tvm_client = tvmauth.TvmClient(
                tvmauth.TvmApiClientSettings(
                    self_tvm_id=tvm_config.client_id,
                    self_secret=tvm_config.secret,
                    dsts={"radius": tvm_config.server_id},
                )
            )

            service = ServiceWithTvm(tvm_client, host=host, port=port, username=username, password=password)
            service.login()
            # Publish the service only after a successful login, so that a
            # socket error during login() does not leave an unauthenticated
            # service behind for search_users() to use.
            self.splunk = service

        except socket.error as err:
            logger.error("Splunk server is not available. %s", err)

    def search_users(self, date):
        """Yield one dict per search result for the given day.

        Generator. Yields nothing when no connection was established.
        Every dict result produces an item with the keys ``ip``,
        ``timestamp``, ``login``, ``conn_type``, ``event``, ``ip_type``,
        ``rec_type`` (always "radius") and ``rec`` (the untouched result).
        ``ip``, ``login`` and ``timestamp`` may be ``None`` when the
        corresponding field is missing; such records are still yielded and
        only counted as failures in the statistics. Non-dict results are
        skipped. Summary statistics are logged once the stream has been
        consumed completely.

        :param date: the day to search, ``YYYY-MM-DD``.
        """
        if not self.splunk:
            return

        # search for a full single day
        time_range = to_splunk_time_range_predicate(date)
        query = 'search {} sourcetype="firewall_logs" login!=""'.format(time_range)
        logger.info("Splunk search query: %s", query)

        # search method can only retrieve 50k records, thus need to use export
        job = self.splunk.jobs.export(query, search_mode="normal")

        processed = 0
        succeed = 0
        fail_reasons = defaultdict(int)
        ip_types = defaultdict(int)
        by_type = defaultdict(lambda: defaultdict(int))
        start_time = datetime.datetime.now()

        reader = results.ResultsReader(io.BufferedReader(ResponseReaderWrapper(job)))
        for idx, rec in enumerate(reader):
            # idx is zero-based, so the number of items seen so far is idx + 1.
            processed = idx + 1

            if idx % 1000 == 0:
                seconds_from_start = (datetime.datetime.now() - start_time).total_seconds()
                logger.info(
                    "Processed %d radius entries in %s", idx, str(datetime.timedelta(seconds=seconds_from_start))
                )

            # The reader can also produce non-dict items; only dicts are results.
            if not isinstance(rec, dict):
                continue

            raw = rec["_raw"]

            ip_raw = find_raw_field("assigned_ip", raw)
            ip, ip_type = get_ip(ip_raw)
            ip_types[ip_type] += 1
            if not ip:
                fail_reasons["no ip"] += 1

            login = find_raw_field("login", raw)
            if not login or login == "UNDEF":
                fail_reasons["undefined login"] += 1
                login = None

            # Prefer the timestamp embedded in the raw line; fall back to
            # the result's own _time field.
            ts = find_raw_field("timestamp", raw)
            if not ts and rec.get("_time"):
                ts = splunk_time_to_unixtime(rec["_time"])
            if not ts:
                fail_reasons["no ts"] += 1

            event = find_raw_field("event", raw, "")
            conn_type = find_raw_field("type", raw, "")

            by_type[conn_type][event] += 1

            if ip and login and ts:
                succeed += 1

            yield {
                "ip": ip,
                "timestamp": ts,
                "login": login,
                "conn_type": conn_type,
                "event": event,
                "ip_type": ip_type,
                "rec_type": "radius",
                "rec": rec,
            }

        logger.info("==== STATS ====")
        logger.info("Processed: %d", processed)
        logger.info("Processed by type:\n%s", pprint.pformat(default_to_regular(by_type)))
        logger.info("Processed by ip type:\n%s", pprint.pformat(default_to_regular(ip_types)))
        logger.info("Succeed: %d", succeed)
        logger.info("Failed:\n%s", pprint.pformat(default_to_regular(fail_reasons)))


def import_radius_log_to_local_file(date, local_folder, splunk_url, splunk_port, splunk_user, splunk_pass, tvm_config):
    """Dump one day of Splunk radius records into ``local_folder + date``.

    Each record is written as one line of tab-separated ``key=value`` pairs
    (falsy values are written as empty strings). A record that cannot be
    processed is logged and skipped without aborting the export.

    :param date: the day to export, ``YYYY-MM-DD``; also used as file name.
    :param local_folder: path prefix of the output file. It is concatenated
        with ``date`` as-is, so it is expected to end with a path separator.
    :param splunk_url: Splunk host.
    :param splunk_port: Splunk port.
    :param splunk_user: Splunk user name.
    :param splunk_pass: Splunk password.
    :param tvm_config: TVM settings passed on to ``SplunkLookup``.
    :raises Exception: "Radius data is not available" when no record with
        an IP was received (the output file has been created by then).
    """
    s = SplunkLookup(splunk_url, splunk_port, splunk_user, splunk_pass, tvm_config)

    target_dir = os.path.dirname(local_folder)
    # dirname() is "" when local_folder has no directory component, and
    # os.makedirs("") raises, so only create a directory that is named.
    if target_dir and not os.path.exists(target_dir):
        os.makedirs(target_dir)

    ips = set()
    logins = set()
    # 0 doubles as the "no timestamp seen yet" marker for min_ts.
    min_ts = 0
    max_ts = 0

    with open(local_folder + date, "w") as fw:
        for rec in s.search_users(date):
            try:
                ip = rec["ip"]
                if ip:
                    ips.add(ip)

                login = rec["login"]
                if login:
                    logins.add(login)

                ts_str = rec["timestamp"]
                if ts_str:
                    # int() handles arbitrarily large values on Python 2 and
                    # 3 alike; long() does not exist on Python 3.
                    ts = int(ts_str)
                    if ts < min_ts or min_ts == 0:
                        min_ts = ts
                    if ts > max_ts:
                        max_ts = ts

                # items() instead of the Python 2 only iteritems().
                line = "\t".join(k + "=" + str(v or "") for k, v in rec.items()) + "\n"
                fw.write(line)
            except Exception:
                # Best-effort per record; KeyboardInterrupt / SystemExit are
                # no longer swallowed as they were by the bare ``except:``.
                logger.exception("Skip radius line %s", rec)

    logger.info("The number of processed unique IPs is %d", len(ips))
    logger.info("The number of processed unique logins is %d", len(logins))

    if not ips:
        raise Exception("Radius data is not available")

    from_dt = ts_to_datetime_str(min_ts)
    to_dt = ts_to_datetime_str(max_ts)
    logger.info("Processed entries from %s to %s", from_dt, to_dt)
