import time
import requests
import traceback
import itertools
import re
import urllib

from graphite.conductor import Conductor
from django.conf import settings
from graphite.logger import log

try:
    from graphite_api.intervals import Interval, IntervalSet
    from graphite_api.node import LeafNode, BranchNode
except ImportError:
    from graphite.intervals import Interval, IntervalSet
    from graphite.node import LeafNode, BranchNode


# Module-level Conductor client shared by conductor_glob() below, which uses
# its CONDUCTOR_EXPR_RE and expandExpression() to expand path parts into hosts.
conductor = Conductor()


def conductor_glob(queries):
    """Expand conductor expressions inside metric patterns.

    Each query is split on '.'; every part that matches
    ``conductor.CONDUCTOR_EXPR_RE`` is replaced by the hosts returned from
    ``conductor.expandExpression`` (with '.' in host names turned into '_').
    A part whose expansion yields no hosts, or that is not a conductor
    expression, is kept verbatim. All combinations of the resulting parts are
    joined back together.

    :param queries: iterable of dot-separated metric patterns.
    :returns: de-duplicated list of expanded patterns (order unspecified).
    """
    expanded = set()
    for query in queries:
        choices = []
        for part in query.split('.'):
            hosts = []
            if conductor.CONDUCTOR_EXPR_RE.match(part):
                # Host names contain dots, which would create extra path levels.
                hosts = [host.replace('.', '_') for host in conductor.expandExpression(part)]
            # Fall back to the literal part when nothing was expanded.
            choices.append(hosts if hosts else [part])
        expanded.update('.'.join(combo) for combo in itertools.product(*choices))
    return list(expanded)


class StockpileFinder(object):
    """Graphite finder that resolves metric patterns against a Stockpile server.

    A pattern is expanded locally first: ``{a,b}`` brace alternatives via
    ``expand_braces``, then conductor expressions via ``conductor_glob``.
    Each resulting pattern is sent to the server's ``/search`` endpoint.
    """
    # Matches one brace group containing at least one comma and no nested
    # braces, e.g. ``{a,b}`` in ``x{a,b}y``. Because inner braces are excluded,
    # nested groups are matched (and expanded) innermost-first.
    braces_re = re.compile(r'({[^{},]*,[^{}]*})')

    def _expand_braces_part(self, part):
        """Recursively expand all brace groups in a single path part.

        :param part: one dot-free component of a metric pattern.
        :returns: de-duplicated list of expansions (``[part]`` if it has none).
        """
        match = self.braces_re.search(part)
        if not match:
            return [part]

        result = set()

        start, end = match.span(1)
        # The matched group has exactly one '{' and one '}', so strip() only
        # removes the delimiters.
        for item in match.group(1).strip('{}').split(','):
            result.update(self._expand_braces_part(part[:start] + item + part[end:]))

        return list(result)

    def expand_braces(self, query):
        """Expand brace alternatives in every dot-separated part of *query*.

        :param query: metric pattern such as ``a.{b,c}.d``.
        :returns: de-duplicated list of patterns without brace groups
            (order unspecified).
        """
        parts = [self._expand_braces_part(part) for part in query.split('.')]
        return list(set('.'.join(p) for p in itertools.product(*parts)))

    def find_nodes(self, query, reqkey, user):
        """Yield graphite nodes matching ``query.pattern``.

        :param query: object exposing a ``pattern`` attribute.
        :param reqkey: request key handed to each created StockpileReader.
        :param user: user name forwarded to the ``/search`` endpoint.
        :yields: ``BranchNode`` for result lines ending in '.' (dot stripped),
            ``LeafNode`` backed by a ``StockpileReader`` otherwise.
        :raises requests.HTTPError: if a search request returns an error status.
        """
        # Resolve the server exactly like StockpileReader.fetch does, so a
        # list-valued STOCKPILE_SERVER setting works here too ('' join is a
        # no-op for a plain string).
        server = ''.join(getattr(settings, 'STOCKPILE_SERVER', '127.0.0.1'))

        patterns = conductor_glob(self.expand_braces(query.pattern))

        metrics = []
        session = requests.Session()
        try:
            for pattern in patterns:
                response = session.get('http://%s:7000/search?%s' % (
                    server, urllib.urlencode({'query': pattern, 'user': user})))
                response.raise_for_status()

                metrics.extend(response.text.split('\n'))
        finally:
            # Release connections even when a request fails mid-loop.
            session.close()

        for metric in metrics:
            if not metric:
                continue

            if metric.endswith('.'):
                yield BranchNode(metric[:-1])
            else:
                yield LeafNode(metric, StockpileReader(metric, reqkey))


class StockpileReader(object):
    """Reader that fetches datapoints for one or many metrics from Stockpile.

    Built from a single path string, the reader fetches that metric. Built
    from an iterable of objects exposing ``.path`` (e.g. other
    StockpileReader instances), it fetches all of them in one query.
    """
    __slots__ = ('path', 'nodes', 'reqkey')

    def __init__(self, path, reqkey=''):
        """Create a reader.

        :param path: metric path string, or an iterable of objects exposing
            ``.path`` to be fetched together.
        :param reqkey: request key included in log lines and in the query.
        """
        self.nodes = [self]
        self.path = None

        # Strings are excluded explicitly: on Python 3 ``str`` has
        # ``__iter__`` and would otherwise be mistaken for a node collection.
        if hasattr(path, '__iter__') and not isinstance(path, (str, type(u''))):
            self.nodes = path
        else:
            self.path = path

        self.reqkey = reqkey

    def get_intervals(self):
        """Report data as available from epoch 0 up to the current time."""
        return IntervalSet([Interval(0, int(time.time()))])

    def fetch(self, startTime, endTime, user):
        """Fetch datapoints for all nodes between *startTime* and *endTime*.

        The response body is line based: the protocol version ("2"), the
        step in seconds, then for each metric a header ``path count
        [from_disk]`` followed by ``count`` lines of ``timestamp value``.
        Empty lines and lines starting with "# " are ignored.

        :returns: for a single-path reader, ``(timeInfo, values)``; otherwise a
            list of ``(node, (timeInfo, values))``. ``timeInfo`` is
            ``(start, end, step)`` with start/end aligned down to the step, and
            ``values`` has one entry per step (None where no point was
            returned). Returns ``[]`` if the HTTP request fails.
        :raises Exception: if the server speaks an unsupported protocol version.
        """
        log.info('DEBUG:stockpile_range:[%s] start = %s, end = %s' % (self.reqkey, startTime, endTime))

        query = self.get_query(startTime, endTime, user)
        log.info('DEBUG:stockpile_query:[%s] count = %d, query = %s' % (self.reqkey, len(query['metrics']), query))

        profilingTime = {
            'start': time.time()
        }

        session = requests.Session()
        try:
            try:
                request = session.post("http://%s:7000/query" % ''.join(getattr(settings, 'STOCKPILE_SERVER', ['127.0.0.1'])),
                                       query,
                                       headers={'Connection': 'close'},
                                       stream=True)
                request.raise_for_status()
            except Exception as e:
                traceback.print_exc()
                log.exception("Failed to fetch data, got exception:\n %s" % str(e))
                return []

            profilingTime['fetch'] = time.time()

            lines = request.iter_lines()

            # A default for next() turns an empty body into the protocol error
            # below instead of a bare StopIteration.
            version = next(lines, None)
            if version != "2":
                raise Exception("Unsupported fetch protocol version: [%s]" % version)
            step = int(next(lines))

            data, lines_from_disk = self._parse_blocks(lines)
        finally:
            # Release the streamed connection on every path, including
            # protocol and parse errors.
            session.close()

        profilingTime['parse'] = time.time()

        timeInfo = (startTime - startTime % step, endTime - endTime % step, step)

        result = []

        for node in self.nodes:
            points = data.get(node.path, {})

            result.append((
                node,
                (
                    timeInfo,
                    [points.get(ts, None) for ts in xrange(timeInfo[0], timeInfo[1] + 1, step)]
                )
            ))

        profilingTime['convert'] = time.time()

        log.info(('DEBUG:stockpile_time:[{req_key}] '
                  'fetch = {fetch_time}, '
                  'parse = {parse_time}, '
                  'convert = {convert_time}, '
                  'from_disk: {from_disk}').format(
            req_key=self.reqkey,
            fetch_time=profilingTime['fetch'] - profilingTime['start'],
            parse_time=profilingTime['parse'] - profilingTime['fetch'],
            convert_time=profilingTime['convert'] - profilingTime['parse'],
            from_disk=lines_from_disk
        ))

        if self.path:
            return result[0][1]

        return result

    @staticmethod
    def _parse_blocks(lines):
        """Parse the per-metric blocks that follow the version and step lines.

        :param lines: iterator over the remaining response lines.
        :returns: ``(data, lines_from_disk)`` where ``data`` maps each metric
            path to ``{timestamp: value}``, and ``lines_from_disk`` is the sum
            of the optional third header field (-1 if no header had one).
        """
        data = {}
        blockpoints = {}
        remaining = 0
        lines_from_disk = -1

        for line in lines:
            if not line or line.startswith("# "):
                continue

            fields = line.split(' ')

            if remaining == 0:
                # Header line: "path count [from_disk]".
                path = fields[0]
                remaining = int(fields[1])
                if len(fields) > 2:
                    line_from_disk = int(fields[2])
                    if lines_from_disk < 0:
                        lines_from_disk = 0
                    lines_from_disk += line_from_disk
                blockpoints = {}
                data[path] = blockpoints
            else:
                # Data line: "timestamp value".
                remaining -= 1
                blockpoints[int(fields[0])] = float(fields[1])

        return data, lines_from_disk

    def get_query(self, startTime, endTime, user):
        """Build the form payload POSTed to the Stockpile ``/query`` endpoint."""
        return {
            'metrics': [node.path for node in self.nodes],
            'from': startTime,
            'until': endTime,
            'req': self.reqkey,
            'user': user,
        }


# Monkeypatch: replace graphite's MultiReader with StockpileReader. Code that
# looks up graphite.readers.MultiReader after this module has been imported
# gets StockpileReader instead. StockpileReader's constructor accepts an
# iterable of nodes, presumably matching how graphite builds a MultiReader —
# NOTE(review): confirm against the graphite version in use.
import graphite.readers

graphite.readers.MultiReader = StockpileReader
