import requests
import grequests
import json
import time

# Hosts running buckytools that are used for hash ring discovery, and the
# port buckytools listens on (the same port serves per-metric requests).
BUCKYTOOLS_HOST_A = "graphite-whisper-a-024726f35d19f52bd.prod.us-west2.justin.tv"
BUCKYTOOLS_HOST_B = "graphite-whisper-b-07142056f2dbbb933.prod.us-west2.justin.tv"
BUCKYTOOLS_PORT = 4242

# graphite-web endpoint used for metric tree lookups (`/metrics/find`).
GRAPHITE_WEB_HOST = "graphite-web.internal.justin.tv"
GRAPHITE_WEB_PORT = 80

# Full `/hashring` URLs; either can be passed to GraphiteCluster as its
# `discovery_endpoint`.
DISCOVERY_A = "http://%s:%d/hashring" % (BUCKYTOOLS_HOST_A, BUCKYTOOLS_PORT)
DISCOVERY_B = "http://%s:%d/hashring" % (BUCKYTOOLS_HOST_B, BUCKYTOOLS_PORT)


class GraphiteCluster():
    """Client for a Graphite cluster whose storage nodes run buckytools.

    Metric tree lookups go through graphite-web (`GRAPHITE_WEB_HOST`).
    Per-metric operations (stat, fetch, backfill, delete) are sent to the
    buckytools daemon on the node that owns the metric according to the
    cluster's hash ring.
    """

    def __init__(self, discovery_endpoint):
        """Build the hash ring by querying a buckytools `/hashring` URL.

        `discovery_endpoint` is the full URL to fetch. The response body is
        parsed as JSON and must contain a "Nodes" list whose items are
        `host:port:instance` strings (exactly three colon-separated fields).
        """
        r = requests.get(discovery_endpoint)
        cluster_data = json.loads(r.text)
        self.hash_ring = GraphiteHashRing()
        for node_spec in cluster_data["Nodes"]:
            host, port, instance = node_spec.split(":")
            self.hash_ring.add_node(GraphiteNode(host, port, instance))

    def _metric_url(self, host, metric):
        """Return the buckytools URL for `metric` on the given `host`."""
        return "http://{h}:{p}/metrics/{m}".format(
            h=host,
            p=BUCKYTOOLS_PORT,
            m=metric
        )

    def list_children(self, path):
        """Expand `path.*` (or `*` for the empty, top-level path).

        Returns the list of child dicts reported by graphite-web. Every
        child that is a leaf additionally gets a 'staleness' key holding the
        number of seconds since buckytools last saw the metric modified.

        Raises RuntimeError if a stat request cannot be sent, and KeyError
        if buckytools answers 404 for one of the leaves.
        """
        # If a top level path, just use `*` as the query
        query = "*" if path == "" else path + ".*"
        url = "http://{h}:{p}/metrics/find?query={q}".format(h=GRAPHITE_WEB_HOST, p=GRAPHITE_WEB_PORT, q=query)
        r = requests.get(url)
        children = json.loads(r.text)

        # Stat the leaves manually instead of calling `stat_metric` so
        # we can use grequests and do all the requests in parallel
        stat_urls = [
            self._metric_url(self.get_metric_node(child['id']).host, child['id'])
            for child in children
            if child['leaf'] == 1
        ]
        if not stat_urls:
            return children

        def exception_handler(request, exception):
            raise RuntimeError("failed to request {r}: {e}".format(
                r=request.url,
                e=str(exception)
            ))

        stat_reqs = (grequests.head(stat_url) for stat_url in stat_urls)
        resps = grequests.map(stat_reqs, exception_handler=exception_handler)
        print(stat_urls)
        print(resps)

        # Index children by id once so each response is matched in O(1)
        # instead of rescanning the whole child list per response.
        children_by_id = {}
        for child in children:
            children_by_id.setdefault(child['id'], []).append(child)

        # inject staleness data into info about leaf metrics
        for resp in resps:
            if resp.status_code == 404:
                raise KeyError("metric not found for call to stat_metric")
            metric_stat = json.loads(resp.headers['X-Metric-Stat'])
            metric = metric_stat['Name']
            staleness = time.time() - metric_stat['ModTime']
            for child in children_by_id.get(metric, ()):
                child['staleness'] = staleness
        return children

    def get_metric_node(self, metric):
        """Return the node (from the hash ring) that owns `metric`."""
        return self.hash_ring.get_node(metric)

    def get_metric_info(self, metric):
        """Return graphite-web's metadata dict for a single leaf metric.

        Returns None when the query does not resolve to exactly one result
        or when that result is not a leaf. This includes the case where
        `metric` points to a valid location in the tree that is NOT a leaf.
        """
        resp = requests.get("http://{h}:{p}/metrics/find?query={m}&format=treejson".format(
            h=GRAPHITE_WEB_HOST,
            p=GRAPHITE_WEB_PORT,
            m=metric
        ))
        # NOTE: the HTTP status is not checked here; the body is parsed as
        # JSON regardless of the response code.
        results = json.loads(resp.text)
        if len(results) != 1 or results[0]['leaf'] != 1:
            return None
        return results[0]

    def stat_metric(self, metric):
        """Return buckytools' stat dict for `metric`.

        The dict is decoded from the `X-Metric-Stat` response header and is
        extended with 'ModTimeDelta': whole seconds elapsed since 'ModTime'.

        Raises KeyError if buckytools reports the metric as missing (404).
        """
        # Use the owning node's host name; formatting the node object itself
        # into the URL would produce an unusable address.
        owner_host = self.get_metric_node(metric).host
        resp = requests.head(self._metric_url(owner_host, metric))
        if resp.status_code == 404:
            raise KeyError("metric {m} does not exist".format(m=metric))
        stats = json.loads(resp.headers['X-Metric-Stat'])
        now = int(time.time())
        # Add time since last update
        stats['ModTimeDelta'] = now - stats['ModTime']
        return stats

    def delete_metric(self, path):
        """Delete the leaf metric at `path` from its owning node.

        Returns None (and deletes nothing) if `path` does not exist or is
        not a leaf; otherwise returns the metric's info dict as it was
        before deletion. Raises for a 4XX or 5XX response from buckytools.
        """
        info = self.get_metric_info(path)
        if info is None:
            return None
        node = self.get_metric_node(path)
        resp = requests.delete(self._metric_url(node.host, info['id']))
        # Raise an error for 4XX or 5XX response
        resp.raise_for_status()
        return info

    def move_metric(self, src, dst):
        """Move the metric `src` to the name `dst`.

        The move fetches the metric data from the node owning `src`, posts
        it to the node that owns `dst`, then deletes `src`. The steps are
        not atomic: if the final delete fails, the data exists under both
        names.

        Raises KeyError if `src` is not an existing leaf metric, and raises
        for a 4XX or 5XX response from the fetch, backfill or delete request.
        """
        # Make sure that the source metric exists
        info = self.get_metric_info(src)
        if info is None:
            raise KeyError("metric '{m}' does not exist".format(m=src))
        # Source and destination may live on different nodes, since
        # ownership is derived from the metric name.
        src_url = self._metric_url(self.get_metric_node(src).host, src)
        dst_url = self._metric_url(self.get_metric_node(dst).host, dst)
        # Get the src data
        src_resp = requests.get(src_url)
        src_resp.raise_for_status()
        print(len(src_resp.content))
        # Backfill the metric with the new name on the new host
        dst_resp = requests.post(
            dst_url,
            headers={'Content-Type': 'application/octet-stream'},
            data=src_resp.content
        )
        dst_resp.raise_for_status()
        # Delete the metric from the previous host
        # Will raise an exception if there is an error
        self.delete_metric(src)


# A GraphiteHashRing that allows for computing where metrics live on the hash ring
class GraphiteHashRing():
    """Consistent hash ring mapping metric names to their owner nodes.

    Each node is placed on the ring `replicas` times. `ring_entries` is kept
    sorted (by RingEntry.compare) so lookups can walk it in position order.
    """

    def __init__(self, replicas=100):
        """Create an empty ring; `replicas` is the entry count per node."""
        self.nodes = []
        self.ring_entries = []
        self.replicas = replicas

    def add_node(self, node):
        """Add `node` to the ring, inserting `self.replicas` entries for it."""
        self.nodes.append(node)
        for i in range(self.replicas):
            replica_key = "{r}-{k}".format(r=i, k=node.key_value())
            replica_position = self.compute_ring_position(replica_key)
            ring_entry = RingEntry(replica_position, node)
            self.insert_ring_entry(ring_entry)

    def compute_ring_position(self, key):
        """Return the ring position of `key`.

        The 32-bit FNV-1a hash is folded to 16 bits by XOR-ing its upper
        and lower halves.
        """
        raw_hash = fnv32a(key)
        digested_hash = (raw_hash >> 16) ^ (raw_hash & 0xFFFF)
        return digested_hash

    def get_node(self, key):
        """Return the node owning `key`.

        That is the node of the first ring entry whose position is greater
        than or equal to the key's position, wrapping around to the first
        entry when the key lies past the last one. An empty ring raises
        ZeroDivisionError.
        """
        pos = self.compute_ring_position(key)
        i = 0
        for e in self.ring_entries:
            if e.position >= pos:
                break
            i += 1
        # loop back to 0 if the position of the key
        # is larger than the largest position in the ring
        return self.ring_entries[i % len(self.ring_entries)].node

    def insert_ring_entry(self, ring_entry):
        """Insert `ring_entry` so that `ring_entries` stays sorted."""
        i = self.bisect_right(ring_entry)
        self.ring_entries.insert(i, ring_entry)

    def bisect_right(self, ring_entry):
        """Return the index of the first entry ordered after `ring_entry`.

        Entries comparing equal to `ring_entry` end up to its left. Returns
        len(ring_entries) when no existing entry is greater.
        """
        i = 0
        for e in self.ring_entries:
            if e.compare(ring_entry) > 0:
                break
            i += 1
        return i


# RingEntry objects are low-level structs making up
# the innards of a hash ring
class RingEntry():
    """One slot on the hash ring: a position together with its node."""

    def __init__(self, position, node):
        self.node = node
        self.position = position

    def compare(self, e):
        """Three-way compare against entry `e`; returns -1, 0 or 1.

        Entries are ordered by position first, then by node host, then by
        node port. A later field is only examined when all earlier fields
        are tied.
        """
        field_getters = (
            lambda entry: entry.position,
            lambda entry: entry.node.host,
            lambda entry: entry.node.port,
        )
        for get_field in field_getters:
            mine = get_field(self)
            theirs = get_field(e)
            if mine < theirs:
                return -1
            if mine > theirs:
                return 1
        return 0


# Create a GraphiteNode holding an addressable host and port,
# and a unique identifier instance string to hash on
class GraphiteNode():
    """A cluster node: addressable host and port plus an instance string."""

    def __init__(self, host, port, instance):
        self.instance = instance
        self.port = port
        self.host = host

    def key_value(self):
        """Return the string this node is hashed on.

        This is the instance identifier, or the host when the instance is
        the empty string.
        """
        return self.host if self.instance == "" else self.instance


# HELPER FUNCTIONS
def fnv32a(str):
    """Return the 32-bit FNV-1a hash of `str` as a non-negative int.

    FNV-1a is defined over bytes, so text input is hashed over its UTF-8
    encoding; bytes and bytearray input is hashed as-is. For pure ASCII
    input this is the same as hashing the character codes.

    NOTE: the parameter name shadows the builtin `str`; it is kept so that
    existing keyword callers keep working.
    """
    if isinstance(str, (bytes, bytearray)):
        data = str
    else:
        data = str.encode("utf-8")
    hval = 0x811c9dc5  # FNV-1a 32-bit offset basis
    fnv_32_prime = 0x01000193
    # bytearray yields ints on both Python 2 and Python 3.
    for byte in bytearray(data):
        # XOR the byte in first, then multiply (the "1a" ordering),
        # truncating to 32 bits.
        hval = ((hval ^ byte) * fnv_32_prime) & 0xFFFFFFFF
    return hval

