# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
import json
import requests
import logging
import time

# Module-level logger named after this module; no handlers or levels are
# configured in this file.
logger = logging.getLogger(__name__)


class SplunkHECSender(object):
    """
    Send data to the infrasec Splunk cluster via the HTTP Event Collector.

    The token must be described in the cluster side config.
    Data can be one of three types: string, dict or list of dicts.

    Usage:
    1. sender = SplunkHECSender("TOKEN")
    2. sender.send_data(data_variable)  # string, dict or list of dicts.

    """

    # Upper bound on the number of events joined into one POST request.
    MAX_BATCH_SIZE = 1000

    # String types accepted as a single event. The module enables
    # ``unicode_literals``, so on Python 2 plain literals are ``unicode`` and
    # a check against ``str`` alone would reject them.
    try:
        _string_types = (basestring,)  # noqa: F821 (Python 2 only)
    except NameError:
        _string_types = (str,)

    def __init__(
        self,
        token,
        src_host=None,
        index=None,
        source=None,
        sourcetype=None,
        timestamp=None,
        hec_host="hatch.yandex.net",
        hec_verify_ssl=False,
        hec_timeout=30,
    ):
        """
        :param token: HEC token, sent as ``Authorization: Splunk <token>``.
        :param src_host: optional ``host`` metadata field added to events.
        :param index: optional ``index`` metadata field added to events.
        :param source: optional ``source`` metadata field added to events.
        :param sourcetype: optional ``sourcetype`` metadata field.
        :param timestamp: optional fixed event time; must be convertible to
            float, otherwise it is ignored and the current time is used.
        :param hec_host: host name used to build the collector URL.
        :param hec_verify_ssl: passed to ``requests`` as ``verify``.
        :param hec_timeout: timeout in seconds passed to ``requests`` so a
            stalled connection cannot block the caller forever; ``None``
            disables the timeout.
        """
        # HEC attributes
        self.hec_host = hec_host
        self.token = token
        self.request_url = "https://%s/services/collector" % self.hec_host
        self.auth_header = {"Authorization": "Splunk %s" % self.token}
        self.hec_timeout = hec_timeout

        # Fields attributes
        self.src_host = src_host
        self.index = index
        self.source = source
        self.sourcetype = sourcetype

        # Set timestamp: keep the value as passed if it looks like a number.
        self.timestamp = None
        if timestamp is not None:
            try:
                float(timestamp)
            except (TypeError, ValueError):
                # TypeError covers non-numeric types such as list or dict.
                logger.error("Timestamp not in epoch format!")
            else:
                self.timestamp = timestamp

        # Verify SSL
        # NOTE(review): verification is off by default, which allows
        # man-in-the-middle attacks; the default is kept for compatibility.
        self.hec_verify_ssl = hec_verify_ssl
        if not self.hec_verify_ssl:
            # Process-wide side effect: silences urllib3 warnings globally.
            requests.packages.urllib3.disable_warnings()

    @staticmethod
    def _get_current_time():
        """Return the current epoch time as a string."""
        return str(time.time())

    def _get_event_time(self):
        """Return the fixed timestamp if one was set, else the current time."""
        if self.timestamp is None:
            return self._get_current_time()
        return self.timestamp

    @staticmethod
    def _slice_list(array, size):
        """
        Split ``array`` into ``size`` consecutive lists of near-equal length.

        The first ``len(array) % size`` lists get one extra element.

        :param array: sized iterable to split.
        :param size: number of lists to produce.
        :return: list of ``size`` lists.
        """
        # Floor division and the next() builtin keep this working on both
        # Python 2 and Python 3.
        slice_size, remain = divmod(len(array), size)
        iterator = iter(array)
        result = []
        for index in range(size):
            length = slice_size + (1 if index < remain else 0)
            result.append([next(iterator) for _ in range(length)])
        return result

    def _send_post_request(self, prep_data):
        """
        POST ``prep_data`` to the collector URL.

        :param prep_data: request body (one or more JSON events).
        :return: True if the response text is "Success", False if the
            collector answered with anything else, None if the request or
            the response parsing raised.
        """
        try:
            post_request = requests.post(
                self.request_url,
                data=prep_data,
                headers=self.auth_header,
                verify=self.hec_verify_ssl,
                timeout=self.hec_timeout,
            )
            response = post_request.text
            response_json = json.loads(response)
        except Exception as err:
            # Best-effort sender: report the failure instead of raising.
            logger.error("Exception raised: %s", err)
            return None

        if isinstance(response_json, dict) and response_json.get("text") == "Success":
            return True
        logger.error("Request post failed, response: %s", response)
        return False

    def _build_payload(self, event):
        """Return ``event`` wrapped with time and metadata as a JSON string."""
        payload = {"event": event, "time": self._get_event_time()}
        metadata = (
            ("host", self.src_host),
            ("index", self.index),
            ("source", self.source),
            ("sourcetype", self.sourcetype),
        )
        for key, value in metadata:
            if value is not None:
                payload[key] = value
        return json.dumps(payload)

    def _send_batches(self, payloads):
        """
        Send serialized events in requests of at most MAX_BATCH_SIZE events.

        Every batch is attempted even if an earlier one fails.

        :param payloads: non-empty list of JSON strings.
        :return: True only if every batch was accepted.
        """
        # Ceiling division, so no batch exceeds MAX_BATCH_SIZE.
        batch_count = -(-len(payloads) // self.MAX_BATCH_SIZE)
        if batch_count > 1:
            logger.info(
                "Count of events > %d, sending %d post requests of at most %d events",
                self.MAX_BATCH_SIZE,
                batch_count,
                self.MAX_BATCH_SIZE,
            )
        else:
            logger.info(
                "Count of events <= %d, sending in single post request",
                self.MAX_BATCH_SIZE,
            )

        all_sent = True
        for batch in self._slice_list(payloads, batch_count):
            if not self._send_post_request("\n".join(batch)):
                all_sent = False
        return all_sent

    def send_data(self, data):
        """
        Send ``data`` to Splunk.

        A list is sent as batched events (items that are not dicts are
        skipped); a dict or a string is sent as a single event.

        :param data: string, dict or list of dicts.
        :return: True if everything was accepted by the collector, False
            otherwise (including unsupported input and empty lists).
        """
        if isinstance(data, list):
            payloads = [
                self._build_payload(item) for item in data if isinstance(item, dict)
            ]
            skipped = len(data) - len(payloads)
            if skipped:
                logger.warning("Skipped %d list item(s) that are not dicts", skipped)
            if not payloads:
                # Nothing to send; avoid posting an empty body.
                logger.warning("No events to send")
                return False
            return self._send_batches(payloads)

        if isinstance(data, dict) or isinstance(data, self._string_types):
            logger.info(
                "Single event passed in str or dict found, sending post request ..."
            )
            return bool(self._send_post_request(self._build_payload(data)))

        logger.error("Unsupported data type: %s", type(data).__name__)
        return False
