"""
Generic library functions.
"""

# pylint: disable=R0903,C0111
# R0903: too-few-public-methods
# C0111: missing-docstring

from contextlib import contextmanager
from os import statvfs, getpid, unlink, makedirs, chmod, chown
from os.path import isfile, isdir, exists, getmtime
import os.path
from time import time, sleep, localtime
import json
from logging import getLogger, WARNING
import logging.handlers
import re
import socket
import itertools
from collections import defaultdict
import random
from functools import partial, wraps
import signal
import copy
from fcntl import flock, LOCK_EX, LOCK_SH, LOCK_UN

import requests

import ttv.config as config

# Library-wide logger, named 'ttv'; the functions in this module log through it.
log = getLogger('ttv')  # pylint: disable=C0103


#############################################################################
#
# logging and alerting
#

# Defaults for RateLimitFilter.  config.var() comes from ttv.config (not shown
# here); presumably it publishes each name as a module global, since later code
# uses ALERT_RATELIMIT_COUNT / ALERT_RATELIMIT_WINDOW bare with pylint E0602
# ("undefined variable") suppressed -- TODO confirm against ttv.config.
config.var('ALERT_RATELIMIT_COUNT', 3) # messages let through per file:line per window
config.var('ALERT_RATELIMIT_WINDOW', 30*60.0) # window length in seconds (30 minutes)

class RateLimitFilter(logging.Filter):
    """A rate-limiting logging.Filter.  Allows up to <count> logs from
    the same file:line every <window> seconds.  At the risk of stating
    the obvious, this only works within a single process. Having it
    work across multiple processes would (likely) involve having a
    persistent server as a chokepoint.

    Records are keyed by (pathname, lineno).  The first record from a
    key, or the first one after its window has expired, opens a new
    window.  Within a window the first <count> records pass, and the
    last of those gets ' (rate limiting)' appended to its message.
    Everything after that is dropped until the window expires.
    """
    def __init__(self, count=ALERT_RATELIMIT_COUNT, window=ALERT_RATELIMIT_WINDOW):  # pylint: disable=E0602
        # can't super logging.Filter in python 2.6 cause it's old-school
        logging.Filter.__init__(self)
        self.count = count
        self.window = window
        self.seen_at = defaultdict(int)     # key -> start time of its current window
        self.seen_count = defaultdict(int)  # key -> records seen in that window

    def filter(self, record):
        key = (record.pathname, record.lineno)
        now = time()

        if self.seen_at[key] < now - self.window:
            # first sighting (seen_at defaults to 0) or window expired
            self.seen_at[key] = now
            self.seen_count[key] = 1
        else:
            self.seen_count[key] += 1

        if self.seen_count[key] < self.count:
            return 1
        elif self.seen_count[key] == self.count:
            # Last record allowed in this window; flag it.  Use %-formatting
            # rather than "+=" because record.msg may be any object (e.g. an
            # exception); "+=" would raise TypeError out of logger.warn() etc.
            # Any %-placeholders in msg survive and are filled from record.args.
            record.msg = '%s (rate limiting)' % (record.msg,)
            return 1
        else:
            return 0


# global logger that emails alerts to the relevant parties (with rate-limiting)
# ALERT_MAILTO (defined below) should be a comma-separated list of email addresses
#
# Use it like this:
#
#   ttv.alerter.warn('Mr. Watson, come here! I need you!')

# Comma-separated recipient list; it is split on ',' to build the alert handler's toaddrs.
config.var('ALERT_MAILTO', 'tools@justin.tv,ops@justin.tv')

class TTVSMTPHandler(logging.handlers.SMTPHandler):
    """SMTPHandler whose email subject names the alert level and source file.

    SMTPHandler.emit() asks getSubject() for the subject of each message,
    so the fixed ``subject`` constructor argument is ignored.
    """

    def getSubject(self, record):  # pylint: disable=C0103
        """Return e.g. 'WARNING alert from /path/to/module.py'."""
        level = record.levelname
        origin = record.pathname
        return '%s alert from %s' % (level, origin)


# define ttv.alerter, along with its options
# (this setup stuff should really be hidden better)
#
# Everything below runs at import time.  Records logged on `alerter` first go
# through a RateLimitFilter attached to the logger itself.  They are then
# mailed by a TTVSMTPHandler via the SMTP server on localhost, from
# ops@justin.tv, to each comma-separated address in ALERT_MAILTO (presumably
# injected as a global by config.var above -- hence the E0602 disable).
# No level is set on the logger or the handler here, so the effective
# threshold is inherited from ancestor loggers.
alerter = getLogger('ttv_alertlogger')  # pylint: disable=C0103
# subject=None is fine: TTVSMTPHandler.getSubject() builds the subject per record.
alerter_handler = TTVSMTPHandler(mailhost='localhost',  # pylint: disable=C0103
                           fromaddr='ops@justin.tv',
                           toaddrs=ALERT_MAILTO.split(','),  # pylint: disable=E0602
                           subject=None)
# The hostname is captured once, at import, and baked into the format string
# so the recipient can tell which machine sent the alert.
alerter_hostname = socket.gethostname()  # pylint: disable=C0103
alerter_formatter = logging.Formatter(  # pylint: disable=C0103
    fmt=alerter_hostname + ': %(pathname)s:%(lineno)s: function %(funcName)s: %(message)s')  # pylint: disable=C0103
alerter_handler.setFormatter(alerter_formatter)
alerter.addHandler(alerter_handler)
# RateLimitFilter() takes its count/window defaults from the ALERT_RATELIMIT_* config vars.
alerter.addFilter(RateLimitFilter())


#############################################################################
#
# math
#

# Decimal (SI) multipliers; each D-prefixed name is an alias of the bare one.
KILO = DKILO = 1000.0
MEGA = DMEGA = DKILO ** 2
GIGA = DGIGA = DKILO ** 3
TERA = DTERA = DKILO ** 4

# Binary (IEC-style) multipliers, powers of 1024.
BKILO = 1024.0
BMEGA = BKILO ** 2
BGIGA = BKILO ** 3
BTERA = BKILO ** 4


def greatest_monotone_sequence(lst, compare):
    """Locate the monotone run in lst whose endpoints differ the most.

    lst is a sequence of numbers.  compare is an ordering predicate such as
    operator.lt or operator.gt.  A run is a stretch of consecutive elements
    where compare(previous, current) holds for every adjacent pair.  Runs
    are ranked by applying compare to their endpoint differences
    (first - last).  So operator.lt selects the tallest continuous rise and
    operator.gt the deepest continuous fall.

    Returns (start_index, end_index).  If no adjacent pair satisfies
    compare, the result is (0, 0).  An empty lst raises IndexError.
    """
    best = (0, 0)
    run_start = run_end = 0

    def beats(candidate, incumbent):
        # rank runs by compare() on their (first - last) differences
        return compare(lst[candidate[0]] - lst[candidate[1]],
                       lst[incumbent[0]] - lst[incumbent[1]])

    for idx in range(1, len(lst)):
        if compare(lst[idx - 1], lst[idx]):
            run_end = idx  # still inside the current run
            continue
        # run broken: keep it if it is the best so far, start anew at idx
        if beats((run_start, run_end), best):
            best = (run_start, run_end)
        run_start = run_end = idx

    # the final run never hits the "broken" branch, so check it here
    if beats((run_start, run_end), best):
        best = (run_start, run_end)
    return best


#############################################################################
#
# lists, strings, sorting, dates, etc.
#

def uniquify(lis):
    """Uniquify a list. (Elements must be hashable.)

    Returns a new list with each distinct element of lis exactly once, in
    order of first appearance.  (A bare list(set(...)) gives an arbitrary,
    hash-dependent order.)  lis may be any iterable.  An unhashable element
    raises TypeError.
    """
    seen = set()
    unique = []
    for item in lis:
        if item not in seen:
            seen.add(item)
            unique.append(item)
    return unique


def chunkify(iterable, chunksize):
    """Lazily yield successive tuples of up to chunksize items from iterable.

    Every tuple except possibly the last has exactly chunksize items.
    Iteration stops at the first empty tuple, so an empty iterable (or a
    chunksize of 0) yields nothing.
    """
    source = iter(iterable)

    def next_chunk():
        return tuple(itertools.islice(source, chunksize))

    # iter(callable, sentinel) keeps calling next_chunk until it returns ()
    for chunk in iter(next_chunk, ()):
        yield chunk


# Text types that must never be treated as "array-like" by hybrid_sorted_keyfunc.
# type(u'') is unicode on Python 2 and str on Python 3; bytes is str on Python 2.
_HSK_STRING_TYPES = (str, bytes, type(u''))


def hybrid_sorted_keyfunc(x):  # pylint: disable=C0103
    """Add spaces on all alpha-num and num-alpha transitions.
    Sorted will do the right thing with a mixed string-num list.
    Thus app2 < app12 in the sort order...
    If it looks like an array (but not a string), sort on element 0...

    Details: str(x) is split at every letter/digit boundary (and on
    whitespace), and all-digit pieces become ints.  So 'app12' ->
    ['app', 12].  None maps to None.  A non-string iterable uses the key of
    its first element, so an empty one raises IndexError.
    """
    # Strings must be excluded explicitly: on Python 3 (and for bytes) they
    # have __iter__, and the first character of a string is again a string,
    # which made this recurse forever.
    if hasattr(x, '__iter__') and not isinstance(x, _HSK_STRING_TYPES):
        return hybrid_sorted_keyfunc(x[0])
    if x is None:
        return None
    x = str(x)
    x = re.sub(r'(\D)(\d)', r'\g<1> \g<2>', x)
    x = re.sub(r'(\d)(\D)', r'\g<1> \g<2>', x)
    key = [int(a) if a.isdigit() else a for a in x.split()]
    return key
hsk = hybrid_sorted_keyfunc  # pylint: disable=C0103


def hybrid_sorted_by_attr(objs, field, reverse=False):
    """Return a new list of objs ordered by hybrid_sorted_keyfunc(getattr(obj, field)).

    reverse flips the order, as with sorted().
    """
    return sorted(objs, reverse=reverse,
                  key=lambda obj: hybrid_sorted_keyfunc(getattr(obj, field)))


def hostname_sorted_keyfunc(x):
    """Sort key for a hostname: reverse its dot-separated labels, then apply hsk.

    'web12.east.example.com' is keyed as 'com.example.east.web12'.  Hosts
    therefore group by domain first and by unqualified name last, and
    embedded numbers compare numerically.  (This function always produces
    a key; sort_hostnames is what declines to sort names without a dot.)
    """
    labels = x.split('.')
    return hsk('.'.join(labels[::-1]))


def sort_hostnames(names):
    """Return names sorted domain-first via hostname_sorted_keyfunc.

    If any name lacks a '.', names is returned unchanged (the very same
    object, not a copy).  Otherwise a new sorted list is returned.
    """
    if any('.' not in name for name in names):
        return names
    return sorted(names, key=hostname_sorted_keyfunc)


def apply_in_chunks(func, iterable, chunksize, progress, throttle):
    """Apply func to successive chunksize (or smaller) subsequences of iterable.

    This is a generator: nothing happens until the caller iterates it, and
    func is applied chunk by chunk as iteration proceeds.

    - func is called once per chunk with a tuple of up to chunksize items.
    - If progress is truthy, a Progress(secs=progress) reporter is used,
      and the running item count is yielded each time it reports.
    - If throttle is truthy, a Throttle ensures each chunk (func call plus
      any wait) takes at least throttle seconds.
    - After the last chunk, the total number of items processed is always
      yielded, whether or not progress reporting is enabled.
    """
    prog = Progress(secs=progress) if progress else None
    throt = Throttle() if throttle else None
    processed = 0  # tracked here so the final yield works without a Progress

    for subseq in chunkify(iterable, chunksize):
        func(subseq)
        processed += len(subseq)
        if prog is not None and prog.click(len(subseq)):
            yield prog.count()
        if throt is not None:
            throt.throttle(throttle)
    yield processed


#############################################################################
#
# files, directories, and OS functions
#

@contextmanager
def autoremove(filename):
    """Context manager that deletes filename when the with-block exits.

    with autoremove(function_that_returns_a_filename()) as f:
        blah blah blah

    # by the time you get here, the file will have been deleted

    The block receives filename unchanged.  On exit, normal or via an
    exception, the file is unlinked if filename is not None and names an
    existing regular file.
    """
    try:
        yield filename
    finally:
        removable = filename is not None and isfile(filename)
        if removable:
            unlink(filename)


def osfunc_ignore(func, *args, **kwargs):
    """Call func(*args, **kwargs), swallowing OSError.

    The call's return value is discarded and None is returned.  An OSError
    is logged at debug level; any other exception propagates.
    """
    failure = None
    try:
        func(*args, **kwargs)
    except OSError as exc:
        failure = exc
    if failure is not None:
        log.debug('%s(%s %s) failed: %s', func, args, kwargs, failure)


def unlink_ignore(filename):
    """Delete filename if it is a regular file; do nothing otherwise.

    Any OSError from the unlink is handled by osfunc_ignore (debug-logged
    and ignored).
    """
    if not isfile(filename):
        return
    osfunc_ignore(unlink, filename)


def makedirs_ignore(path, mode=0o777):
    """Create directory path, including missing parents, unless it is already a directory.

    mode is passed to os.makedirs, which applies the process umask.  Any
    OSError (e.g. path exists as a regular file, or permission denied) is
    handled by osfunc_ignore: debug-logged and ignored.
    """
    # 0o777 rather than the legacy 0777: same value, and valid on Python 2.6+
    # as well as Python 3 (where 0777 is a syntax error).
    if not isdir(path):
        osfunc_ignore(makedirs, path, mode)


def chmod_ignore(path, mode):
    """chmod path to mode if path exists; do nothing otherwise.

    Any OSError from os.chmod is debug-logged and ignored (via osfunc_ignore).
    """
    if not exists(path):
        return
    osfunc_ignore(chmod, path, mode)


def chown_ignore(path, user, group):
    """chown path to (user, group) if path exists; do nothing otherwise.

    user and group are passed straight to os.chown, which expects numeric
    uid/gid values.  Any OSError is debug-logged and ignored (via
    osfunc_ignore).
    """
    if not exists(path):
        return
    osfunc_ignore(chown, path, user, group)


def percent_full(path):
    """Return how full (0-100, as a float) the filesystem holding path is.

    Free space is counted as f_bavail (blocks available to unprivileged
    users), so blocks reserved for root count as used.
    """
    st = statvfs(path)
    free_fraction = float(st.f_bavail) / st.f_blocks
    return 100 * (1 - free_fraction)


JTV_WILDCARD = '192.16.71.174'

def dnsresolve(hostname):
    """Return hostname's IPv4 address as a string, or None if it doesn't resolve.

    A lookup that answers with JTV_WILDCARD (a catch-all address) is
    treated as not resolving.  Only socket.gaierror is caught.
    """
    try:
        ipaddr = socket.gethostbyname(hostname)
    except socket.gaierror:
        return None
    return None if ipaddr == JTV_WILDCARD else ipaddr

#############################################################################
#
# requests package
#

def requests_wrapper(func, url, quiet=False, **kwargs):
    """Simplified requests() library call: either it works completely,
    or it logs and returns None.

    func is a requests-style callable such as requests.get.  Its upper-cased
    __name__ is used in log messages.  url and **kwargs are passed through to
    func.  With quiet=True nothing is logged.

    Returns the response when func returns a non-None object whose .ok is
    true.  Otherwise returns None: non-ok status, None result, connection/DNS
    failures, timeouts, and any other requests.RequestException.
    """
    funcname = func.__name__.upper()
    # NOTE(review): kwargs (headers, auth, data, ...) are logged verbatim;
    # don't route secrets through here without scrubbing them first.
    if not quiet:
        log.debug('%s %s %s', funcname, url, kwargs)

    try:
        r = func(url, **kwargs)
    except (requests.ConnectionError, socket.gaierror) as e:
        if not quiet:
            log.error('%s %s %s failed: %s', funcname, url, kwargs, e)
        return None
    except requests.Timeout as e:
        if not quiet:
            log.error('%s %s %s timed out: %s', funcname, url, kwargs, e)
        return None
    except requests.RequestException as e:
        # Everything else requests raises (missing/invalid URL schema, too
        # many redirects, broken chunked encoding, ...): honour the
        # "works completely or returns None" contract instead of leaking it.
        if not quiet:
            log.error('%s %s %s failed: %s', funcname, url, kwargs, e)
        return None

    # A requests Response is falsy for non-ok status codes, so test against None explicitly.
    if r is None:
        if not quiet:
            log.error('%s %s %s failed', funcname, url, kwargs)
        return None

    if not quiet:
        log.debug('r: %s', r)

    if not r.ok:
        if not quiet:
            log.error('%s %s %s failed: %s (%s)', funcname, url, kwargs, r.status_code, r.text)
        return None

    return r


# add a requests.move() here...?


#############################################################################
#
# control flow
#

def sleep_until_secs(secs_list):
    """Sleep until the seconds-past-the-minute next reaches a value in secs_list.

    For example, [0, 30] wakes at the next :00 or :30.  Each mark's wait
    is computed modulo 60 and the shortest one is slept.  An empty
    secs_list raises ValueError (from min()).
    """
    # time() % 60 is taken as the current second-of-minute, which assumes
    # the epoch falls on a :00 boundary in all timezones.  Keep the float:
    # with whole seconds the computed wait would be off by up to a second.
    now_in_minute = time() % 60
    delays = [(mark - now_in_minute) % 60 for mark in secs_list]
    wait = min(delays)
    log.debug('sleep_until_secs: sleeping %s', wait)
    sleep(wait)


class Throttle(object):
    """Enforce a minimum interval between loop iterations.  Use like so:

    t = Throttle()
    while True:
        do_stuff()
        t.throttle(20)

    This ensures that the loop executes at most once per 20 seconds.

    """

    def __init__(self):
        self.timestamp = time()  # end of the most recent throttle() (or creation)

    def throttle(self, seconds):
        """Sleep as needed so >= seconds have passed since the last call (or creation)."""
        remaining = (self.timestamp + seconds) - time()
        if remaining > 0:
            log.debug('throttle: sleeping %s', remaining)
            sleep(remaining)
        self.timestamp = time()


#############################################################################
#
# UI
#

class Progress(object):
    """Periodic progress reporting.

    click() adds to a running count and returns True once every secs
    seconds or so (False otherwise).  When it does, read the total with
    count():

    p = Progress()
    while True:
        do_stuff()
        if p.click():
            print p.count(), '...'

    """

    def __init__(self, secs=10):
        self.secs = secs
        self.cnt = 0
        self.next = time() + secs  # when the next click() should report

    def click(self, incr=1):
        """Add incr to the count; return True if a report is due, else False."""
        self.cnt += incr
        due = time() >= self.next
        if due:
            self.next = time() + self.secs
        return due

    def count(self):
        """Return the accumulated count."""
        return self.cnt


class Spinner(Progress):
    """Text spinner built on Progress.

    click() returns True about every secs seconds (0.25 by default) and
    then advances the spinner.  phase() gives a backspace followed by the
    current spinner character, ready to write to a terminal:

    s = Spinner()
    while True:
       do_stuff()
       if s.click():
           sys.stdout.write(s.phase())
           sys.stdout.flush()

    """

    SPINNER = r'/-\|'

    def __init__(self, secs=0.25):
        super(Spinner, self).__init__(secs)
        self.spincount = 0  # index into SPINNER

    def click(self, incr=1):
        """Count incr; when a report is due also step to the next spinner frame."""
        ticked = super(Spinner, self).click(incr)
        if ticked:
            self.spincount = (self.spincount + 1) % len(self.SPINNER)
        return ticked

    def phase(self):
        """Backspace plus the current spinner character."""
        return '\b' + self.SPINNER[self.spincount]


# 512 light hex colours: each RGB channel takes 8 levels from 250 down to 187.
# The list is shuffled at import time, so its order differs between processes
# unless the random module has been seeded beforehand.
PALECOLORS = ['%02x%02x%02x' % rgb
              for rgb in itertools.product(range(250, 180, -9), repeat=3)]
random.shuffle(PALECOLORS)

def hashcolor(token, palette):
    """Pick a palette entry for token via hash(token) modulo len(palette).

    The same token maps to the same entry within one process.  Note that
    Python 3 randomizes str hashes per process unless PYTHONHASHSEED is
    fixed.  An empty palette raises ZeroDivisionError.
    """
    index = hash(token) % len(palette)
    return palette[index]


#############################################################################
#
# decorators and other funny business
#


def dither(base, fraction=0.1):
    """Return a 0-arg function giving a uniform random float in base * (1 +/- fraction).

    Useful as the secs argument of the cache decorators below, so that
    many caches don't all expire at the same moment.
    """
    low = base * (1 - fraction)
    high = base * (1 + fraction)
    return partial(random.uniform, low, high)


class TimeoutException(Exception):
    """Raised by timelimit() (and hence @timeout) when the time limit expires."""


@contextmanager
def timelimit(secs):
    """Throw a TimeoutException if a block takes longer than secs.

    Usage:

    try:
        with timelimit(10):
            do_stuff
    except TimeoutException:
        print 'fail'

    Implemented with SIGALRM and the ITIMER_REAL timer, so it needs a
    platform that has both, and signal.signal() only works in the main
    thread.  secs may be fractional; a secs of 0 disarms the timer (no
    limit).

    Nesting is supported.  An ITIMER_REAL timer already running when the
    block is entered (e.g. an enclosing timelimit or @timeout) is re-armed
    on exit with whatever time it had left, together with its original
    interval.  If that time ran out meanwhile, it fires almost immediately.
    It cannot fire *during* the inner block, because there is only one
    ITIMER_REAL per process.
    """
    def handler(signum, frame):  # pylint: disable=W0613
        raise TimeoutException()
    old_handler = signal.signal(signal.SIGALRM, handler)
    started = time()
    # setitimer returns the (delay, interval) of the timer it replaces.
    outer_delay, outer_interval = signal.setitimer(signal.ITIMER_REAL, secs)
    try:
        yield
    finally:
        # Disarm before restoring the handler, so a late alarm can't reach
        # the previous handler (which may be SIG_DFL and kill the process).
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old_handler)
        if outer_delay > 0:
            remaining = outer_delay - (time() - started)
            # A delay of 0 would disarm the timer, so an expired outer
            # deadline gets a tiny positive delay instead.
            signal.setitimer(signal.ITIMER_REAL, max(remaining, 1e-6), outer_interval)


def timeout(secs):
    """Decorator: throw a TimeoutException if the function takes longer than secs.

    @timeout(3.0)
    def foofunc():
        do_stuff

    Each call runs inside timelimit(secs).  functools.wraps preserves the
    wrapped function's name and docstring.
    """

    def decorate(func):
        def limited(*args, **kwargs):
            with timelimit(secs):
                result = func(*args, **kwargs)
            return result
        return wraps(func)(limited)

    return decorate


def cached(secs, deepcopy=False, dont_cache_none=False, use_stale_non_none=False):
    """Cache the return values from this function for secs seconds.
    Note this only works right if it's a pure function with no
    side-effects.  Also, this may not work on built-in functions.

    If deepcopy is True, a deepcopy of the return value is made before
    returning it.

    If dont_cache_none is True, then None return values will not be
    cached.

    If use_stale_non_none is True, then not only will None return
    values not be cached, but a stale cached non-None value will be
    returned in preference to a None value.

    Notes:

    - Not thread safe. The pid is included for debugging only.

    - secs can be a number, or a 0-arg function returning a number.

    - The cache key is the JSON encoding of the positional and keyword
      arguments, so they must be JSON-serializable.

    - if _bypass_cache=True is passed, any cached value is ignored for
      that call.  The function is re-run and its result refreshes the
      cache entry for the same arguments.  The _bypass_cache keyword is
      consumed by the wrapper and is NOT forwarded to the function.

    - if the function returns a complex data structure and you want
      fresh, use deepcopy=True

    Usage:

    @cached(60)
    def f(foo):
        # some function whose value depends on foo and may change
        # occasionally but not frequently

    or

    @cached(dither(60))
    """

    # pylint: disable=C0103,W0212

    def deco(f):

        @wraps(f)
        def g(*args, **kwargs):
            # Strip the wrapper-only flag.  Otherwise it reaches f (TypeError
            # for an unexpected keyword) and becomes part of the cache key
            # (so the fresh value lands under a key normal calls never use).
            bypass = kwargs.pop('_bypass_cache', False)
            if not hasattr(f, '_cache'):
                setattr(f, '_cache', defaultdict(lambda: {'value': None, 'expires': 0}))
            key = json.dumps({'args': args, 'kwargs': kwargs, 'pid': getpid()})
            now = time()
            if bypass or key not in f._cache or now >= f._cache[key]['expires']:
                value = f(*args, **kwargs)
                expire_time = now
                expire_time += secs if hasattr(secs, '__float__') else secs()
                if value is None and use_stale_non_none:
                    # fall back to the last cached value (None if there is none)
                    value = f._cache[key]['value']
                elif value is None and dont_cache_none:
                    pass
                else:
                    f._cache[key]['value'] = value
                    f._cache[key]['expires'] = expire_time
            else:
                value = f._cache[key]['value']

            if deepcopy:
                return copy.deepcopy(value)
            return value

        return g
    return deco


def cached_http_getter(secs):
    """Returns a function that can be used to cache static HTTP stuff for secs seconds.

    The returned getter(url) calls requests_wrapper(requests.get, url) and
    memoizes the response per url via cached().  secs may be a number or a
    0-arg function returning one.

    Failed fetches (requests_wrapper returned None) are not cached.  A
    transient failure is retried on the next call instead of being served
    for the whole caching period, which with float('Inf') would be forever.

    Usage:

    get = cached_http_getter(10)
    foo = get('http://kimchi-api.justin.tv/rest/objects/?type=router')
    bar = get('http://kimchi-api.justin.tv/rest/objects/?type=router')  # same as foo
    # 11 seconds go by
    baz = get('http://kimchi-api.justin.tv/rest/objects/?type=router')  # may be different

    get_dithered = cached_http_getter(dither(10))  # another possibility...
    get_forever = cached_http_getter(float('Inf'))  # and another possibility...
    ...
    """
    @cached(secs, dont_cache_none=True)
    def getter(url):
        return requests_wrapper(requests.get, url)
    return getter


def objcached(secs):
    """Class decorator for object caching.

    Instances constructed with the same arguments (JSON-encoded together
    with the pid) within the caching period share one __dict__.  The first
    construction runs the real __init__; later ones just adopt the cached
    state, so a change made through one object is seen by all of them.
    Once the period expires, the next construction runs __init__ again and
    starts a fresh shared state.  Objects created earlier keep the old one.

    Notes:

    - Not thread safe. The pid is included for debugging only.
    - secs can be a number, or a 0-arg function returning a number
      (evaluated once per refresh).
    - Constructor arguments must be JSON-serializable.
    - It would be nice to eliminate the code duplication with cached (above)

    Usage:

    @objcached(10)
    class Foo(object):
       def __init__(self, *args, **kwargs):
           ...

    All new Foo objects with the same args will share state, and they
    will all get refreshed every 10 seconds.

    @objcached(function_returning_a_number)

    As above, but refresh interval is determined by calling the
    function on a per-refresh basis.

    Suggestion for a function: partial(random.uniform, 108, 132)

    Or equivalently, dither(120, 0.1)

    (This gives you cache dither, if you want it.)
    """
    def decorator(cls):
        # pylint: disable=W0212
        original_init = cls.__init__
        cls._cache = {}

        def __init__(self, *args, **kwargs):
            key = json.dumps({'args': args, 'kwargs': kwargs, 'pid': getpid()})
            now = time()
            entry = cls._cache.get(key)
            if entry is not None and now < entry['expires']:
                self.__dict__ = entry['obj']  # adopt the shared state
                return
            original_init(self, *args, **kwargs)
            lifetime = secs if hasattr(secs, '__float__') else secs()
            cls._cache[key] = {'obj': self.__dict__, 'expires': now + lifetime}

        cls.__init__ = __init__
        return cls

    return decorator


def borg(cls):
    """Class decorator to give something the borg nature... (google
    "borg design pattern" for more.)

    Equivalent to objcached with an infinite lifetime: every instance built
    with the same arguments shares one state for the life of the process.

    WARNING: Forever is a long time. If you're dealing with something
    that could ever change due to external forces, use objcached() with
    <= 20 minutes or something instead.
    """
    share_forever = objcached(float('Inf'))
    return share_forever(cls)


#############################################################################
#
# json cache file etc.
#


@contextmanager
def flocked(filedes, oper):
    """Hold an flock() of type oper on filedes for the duration of a with-block.

    filedes is anything fcntl.flock accepts (a file object or a descriptor).
    oper is LOCK_SH or LOCK_EX, optionally or-ed with LOCK_NB.  The lock is
    released (LOCK_UN) on exit, even if the block raises.
    """
    flock(filedes, oper)
    try:
        yield
    finally:
        flock(filedes, LOCK_UN)


class JsonCacheFile(object):
    """A simple object for saving stuff via json.

    The file holds a single JSON document.  Readers take a shared flock and
    writers an exclusive one.  Provided every reader and writer goes through
    this class (flock is advisory), nobody observes a half-written file.
    """

    def __init__(self, filename):
        self.filename = filename
        dname = os.path.dirname(self.filename)
        # A bare filename has dirname '' -- nothing to create (makedirs('')
        # would raise).  Another process may create the directory between the
        # isdir() check and makedirs(); that is not an error.
        if dname and not isdir(dname):
            try:
                makedirs(dname, 0o777)
            except OSError:
                if not isdir(dname):
                    raise

    def read(self, maxage=float('inf')):
        """Return the decoded JSON contents, or None.

        None is returned if the file doesn't exist, was last modified more
        than maxage seconds ago, or doesn't contain valid JSON.  maxage may
        be a number or anything float() accepts (e.g. the old default
        '+Inf').
        """
        if not isfile(self.filename):
            return None
        # float(): comparing a float with the old string default '+Inf' only
        # "worked" via Python 2's cross-type ordering (TypeError on Python 3).
        if time() - getmtime(self.filename) > float(maxage):
            return None
        with open(self.filename, 'r') as cache:
            with flocked(cache, LOCK_SH):
                contents = cache.read()
        try:
            return json.loads(contents)
        except ValueError:
            return None

    def write(self, data):
        """Replace the file's contents with json.dumps(data)."""
        # Serialize first, so unserializable data (TypeError) can't leave
        # the file truncated.
        payload = json.dumps(data)
        # 'a+' rather than 'w': 'w' would truncate before we hold the lock.
        with open(self.filename, 'a+') as cache:
            with flocked(cache, LOCK_EX):
                cache.seek(0)
                cache.truncate()
                cache.write(payload)
                # Flush while still holding the lock.  Otherwise the data sits
                # in our buffer until close(), and a reader could take the
                # lock in between and see an empty file.
                cache.flush()
