#!/usr/bin/env python
#! -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division
from __future__ import print_function
import sys
import os
import re
import json
import codecs
import logging
import toml
import pdb
import math
import argparse
import traceback
import mapreducelib
import threading
import subprocess
import shlex
from time import sleep
try:
    import thread
except ImportError:
    import _thread as thread
from mapreducelib import MapReduce, Record
import urlparse
from collections import defaultdict, Counter, namedtuple
import datetime
from pecheny.mrdef import defaults
from pecheny.moncommons import push_to_razladki


class UTC(datetime.tzinfo):
    """Fixed-offset tzinfo with a zero UTC offset and no DST."""

    def utcoffset(self, dt):
        """Return the offset from UTC, which is always zero."""
        return datetime.timedelta(0)

    def tzname(self, dt):
        """Return the zone name as a native ``str``."""
        # This file enables unicode_literals, so a bare literal is ``unicode``
        # on Python 2, and Python 2's datetime raises TypeError when tzname()
        # returns anything but ``str``/None. str() is a no-op on Python 3.
        return str("GMT")

    def dst(self, dt):
        """Return the DST adjustment, which is always zero."""
        return datetime.timedelta(0)

    def __repr__(self):
        return "UTC"


class Moscow(datetime.tzinfo):
    """Fixed-offset tzinfo three hours ahead of UTC, with no DST."""

    def utcoffset(self, dt):
        """Return the fixed +3 hour offset from UTC."""
        return datetime.timedelta(hours=3)

    def tzname(self, dt):
        """Return the zone name as a native ``str``."""
        # This file enables unicode_literals, so a bare literal is ``unicode``
        # on Python 2, and Python 2's datetime raises TypeError when tzname()
        # returns anything but ``str``/None. str() is a no-op on Python 3.
        return str("Europe/Moscow")

    def dst(self, dt):
        """Return the DST adjustment, which is always zero."""
        return datetime.timedelta(0)

    def __repr__(self):
        return "Europe/Moscow (UTC+3)"

# Shared tzinfo instances, plus the Unix epoch as an aware UTC datetime
# (dt_to_ts subtracts it to turn a datetime into seconds).
utc = UTC()
moscow = Moscow()
epoch = datetime.datetime(1970, 1, 1, tzinfo=utc)


def aware_now():
    """Return the current time as an aware datetime in the ``moscow`` zone."""
    current = datetime.datetime.now(tz=moscow)
    return current


def aware_fromtimestamp(ts):
    """Convert Unix timestamp ``ts`` to an aware datetime in the ``moscow`` zone."""
    converted = datetime.datetime.fromtimestamp(ts, tz=moscow)
    return converted


def aware_strptime(string, fmt):
    """Parse ``string`` with ``fmt``, defaulting the zone to ``moscow``.

    When ``fmt`` contains a ``%z`` or ``%Z`` directive the result of
    ``strptime`` is returned untouched; otherwise the parsed value is
    tagged with the ``moscow`` tzinfo.
    """
    parsed = datetime.datetime.strptime(string, fmt)
    if '%z' in fmt or '%Z' in fmt:
        return parsed
    return parsed.replace(tzinfo=moscow)


def diff_sec(dt1, dt2, toint=True):
    """Return ``dt1 - dt2`` in seconds.

    The result is rounded to the nearest integer unless ``toint`` is false,
    in which case the raw float from ``total_seconds()`` is returned.
    """
    seconds = (dt1 - dt2).total_seconds()
    return int(round(seconds)) if toint else seconds


def dt_to_ts(dt):
    """Return the whole number of seconds between ``epoch`` and ``dt``."""
    elapsed = dt - epoch
    return int(round(elapsed.total_seconds()))


def arbitrary_round(num, precision):
    """Round ``num`` to the nearest multiple of ``precision``."""
    steps = round(num / float(precision))
    return steps * precision


def dt_round(dt, precision):
    """Round ``dt`` to the nearest multiple of ``precision``.

    ``precision`` is either an int number of seconds or a ``timedelta``
    (converted to whole seconds). The result is an aware datetime in the
    ``moscow`` zone.
    """
    if isinstance(precision, datetime.timedelta):
        step = int(round(precision.total_seconds()))
    else:
        step = precision
    assert isinstance(step, int)
    rounded_ts = arbitrary_round(dt_to_ts(dt), step)
    return datetime.datetime.fromtimestamp(rounded_ts, moscow)


def deutf8ify(rec):
    """Decode UTF-8 byte strings in ``rec`` to unicode.

    A ``SubkeyedRecord`` yields a new ``Record`` whose key, subkey and value
    are unicode; a plain ``str`` is decoded; anything else is returned as is.
    Undecodable bytes are replaced rather than raising.
    """
    def _to_unicode(field):
        if isinstance(field, unicode):
            return field
        return field.decode('utf8', errors='replace')

    if isinstance(rec, mapreducelib.SubkeyedRecord):
        return Record(_to_unicode(rec.key),
                      _to_unicode(rec.subkey),
                      _to_unicode(rec.value))
    if isinstance(rec, str):
        return rec.decode('utf8', errors='replace')
    return rec


def utf8ify(rec):
    """Encode unicode content of ``rec`` to UTF-8 byte strings.

    A ``SubkeyedRecord`` is modified in place (key, subkey and value) and
    returned; a unicode string is returned encoded; anything else is
    returned unchanged.
    """
    if isinstance(rec, mapreducelib.SubkeyedRecord):
        for field in ('key', 'subkey', 'value'):
            current = getattr(rec, field)
            if isinstance(current, unicode):
                setattr(rec, field, current.encode('utf8'))
        return rec
    if isinstance(rec, unicode):
        return rec.encode('utf8')
    return rec


def tryint(string):
    """Return ``int(string)``, or -1 when the value cannot be converted.

    Only conversion failures are swallowed; KeyboardInterrupt and SystemExit
    propagate instead of being silently turned into -1.
    """
    try:
        return int(string)
    except (TypeError, ValueError, OverflowError):
        return -1


def tabulate(*args):
    """Return the arguments formatted with ``format`` and joined by tabs."""
    cells = [format(arg) for arg in args]
    return '\t'.join(cells)


def get_score(j, n=0):
    """Return the score of document ``n`` from the second client result.

    Reads ``j['answer']['client-results'][1]['docs'][n]['score']`` and
    returns None when any part of that path is missing, including when
    either list is too short or an intermediate value is not a container.
    """
    try:
        return j['answer']['client-results'][1]['docs'][n]['score']
    except (KeyError, IndexError, TypeError):
        # IndexError: fewer than two client results, or fewer than n+1 docs.
        return None


def get_reqid(j):
    """Extract the URL-encoded ``wprid`` value from the access entry.

    Looks in ``j['params']['access_entry']['access']`` for
    ``%22wprid%22%3A%22<id>%22`` and returns ``<id>``. Returns an empty
    string when the access entry is missing or holds no such id.
    """
    try:
        access = j['params']['access_entry']['access']
    except KeyError:
        # Previously fell through with ``access`` unbound and crashed.
        return ''
    found = re.search(r'%22wprid%22%3A%22([0-9a-zA-Z\-]*)%22', access)
    if found is None:
        return ''
    return found.group(1)


def get_client(j):
    """Return ``j['params']['client']``, or '' when either key is absent."""
    try:
        params = j['params']
        client = params['client']
    except KeyError:
        return ''
    return client


def get_clients_scores(j):
    """Map each client result's name to the score of its first document.

    Reads ``j['answer']['client-results']``; entries lacking a name, a first
    document or a score are skipped. Always returns a dict, which is empty
    when the answer or its client results are missing. (It used to return
    None in that case; both are falsy.)
    """
    result = {}
    if 'answer' not in j or 'client-results' not in j['answer']:
        return result
    for entry in j['answer']['client-results']:
        try:
            result[entry['name']] = entry['docs'][0]['score']
        except (KeyError, IndexError, TypeError):
            # IndexError: empty ``docs``; TypeError: malformed entry.
            continue
    return result


def atommap(rec):
    """Map one log record to per-client score records.

    The first 10 characters of the record's subkey are parsed as a Unix
    timestamp and the record's value as JSON. For records that have both
    client scores and a request id, one record per client is yielded:
    key = timestamp rounded to 30 minutes, subkey = client name,
    value = score. Records with an unparsable timestamp yield nothing.
    """
    rec1 = deutf8ify(rec)
    raw_ts = rec1.subkey[:10]
    j = json.loads(rec1.value)
    try:
        ts = aware_fromtimestamp(int(raw_ts))
    except ValueError:
        return
    ts_rounded = dt_round(ts, datetime.timedelta(minutes=30))
    # Extract the scores once instead of re-walking the JSON for every client.
    scores = get_clients_scores(j)
    if not (scores and get_reqid(j)):
        return
    bucket = format(dt_to_ts(ts_rounded))
    for client, score in scores.items():
        yield utf8ify(Record(bucket, format(client), format(score)))


def atomreduce(key, recs):
    """Count how often each score occurs per client within one key.

    ``recs`` carry the client in ``subkey`` and the score in ``value``. For
    every (client, score) pair a record is yielded with
    key = "<key> <client>", subkey = score and value = occurrence count.
    """
    key = deutf8ify(key)
    cnt = defaultdict(Counter)
    for rec in recs:
        # Decode the client name so it can be concatenated with the unicode
        # key below; on Python 2, mixing in raw non-ASCII bytes would raise
        # UnicodeDecodeError.
        cnt[deutf8ify(rec.subkey)][float(rec.value)] += 1
    for client, scores in cnt.items():
        for score, occurrences in scores.items():
            yield utf8ify(Record(key + ' ' + client,
                                 format(score), format(occurrences)))


def get_lastts():
    """Return the last processed date as an aware datetime.

    The date is read (as ``YYYYMMDD``) from the file
    ``atomlog_score_last_timestamp`` in the current directory; when that
    file does not exist the default ``20160117`` is used.
    """
    state_file = 'atomlog_score_last_timestamp'
    if not os.path.isfile(state_file):
        return aware_strptime('20160117', '%Y%m%d')
    with open(state_file) as handle:
        stamp = handle.read().decode('utf8').rstrip()
    return aware_strptime(stamp, '%Y%m%d')


def set_lastts(ts):
    """Persist ``ts`` as ``YYYYMMDD`` in ``atomlog_score_last_timestamp``.

    The file is written in the current directory and overwritten if present.
    """
    with open('atomlog_score_last_timestamp', 'w') as handle:
        stamp = ts.strftime('%Y%m%d')
        handle.write(format(stamp))


def tstable(table):
    """Return the date encoded in the last path component of ``table``.

    The component is parsed as ``YYYYMMDD``; when it does not parse, the
    Unix epoch (aware, UTC) is returned instead.
    """
    suffix = table.rsplit('/', 1)[-1]
    try:
        parsed = aware_strptime(suffix, '%Y%m%d')
    except ValueError:
        parsed = datetime.datetime(1970, 1, 1, tzinfo=utc)
    return parsed


def get_srctables(lb=None, ub=None, alltables=None):
    """Return the tables whose date lies in the interval ``(lb, ub]``.

    ``alltables`` defaults to ``get_alltables()``, ``lb`` to the last
    processed date from ``get_lastts()`` and ``ub`` to 2099-01-01 in the
    ``moscow`` zone. The order of ``alltables`` is preserved.
    """
    alltables = alltables or get_alltables()
    lb = lb or get_lastts()
    ub = ub or datetime.datetime(2099, 1, 1, tzinfo=moscow)
    selected = []
    for name in alltables:
        stamp = tstable(name)
        if stamp > lb and stamp <= ub:
            selected.append(name)
    return selected


def get_alltables():
    """Return the sorted names of all tables matching ``mobilesearch_answer/*``."""
    infos = MapReduce.getTablesInfo('mobilesearch_answer/*')
    return sorted(info.name for info in infos)


def table_exists(table):
    """Return True when the table info for ``table`` reports a size above zero."""
    return MapReduce.getTableInfo(table).size > 0


def choose_interval(num, intervals):
    """Return the first ``(low, high)`` in ``intervals`` that contains ``num``.

    Both bounds are inclusive. Raises IndexError when no interval matches.
    """
    matching = []
    for bounds in intervals:
        if num >= bounds[0] and num <= bounds[1]:
            matching.append(bounds)
    return matching[0]


def counter_quantile(counter, quantile):
    """Return the ``quantile`` (0..1) of the values counted in ``counter``.

    ``counter`` maps a value to its number of occurrences. Conceptually the
    values are expanded into a sorted list; the quantile position is
    ``(total - 1) * quantile``. An integral position returns that element,
    otherwise the mean of the two neighbouring elements is returned.

    Returns 0 for a counter with no observations. A counter with a single
    observation returns that observation (it used to return 0).
    NOTE(review): counts are assumed to be non-negative - confirm.
    """
    # Lay the sorted values out as consecutive inclusive index ranges:
    # (first_index, last_index, value).
    spans = []
    first = 0
    for value in sorted(k for k in counter if counter[k] != 0):
        last = first + counter[value] - 1
        spans.append((first, last, value))
        first = last + 1

    def value_at(index):
        # Value occupying position ``index`` of the expanded sorted list.
        return [val for low, high, val in spans
                if index >= low and index <= high][0]

    last_index = sum(counter.values()) - 1
    if last_index < 0:
        return 0
    target = last_index * quantile
    if int(target) == target:
        return value_at(target)
    return (value_at(math.floor(target)) +
            value_at(math.ceil(target))) / 2.0


def main():
    """Command-line entry point.

    Sets up logging, loads the TOML configuration, then processes source
    tables. Without both ``--datefrom`` and ``--dateto`` it processes every
    table newer than the stored last timestamp, advancing that timestamp
    after each successful table, and exits with status 0. With both options
    it processes the tables dated in ``(datefrom, dateto]`` without touching
    the stored timestamp.
    """
    global __file__                         # to fix stupid
    __file__ = os.path.abspath(__file__)    # __file__ handling
    _file_ = os.path.basename(__file__)     # in python 2
    script_dir = os.path.dirname(__file__)
    # Basename minus its last three characters (presumably the ".py" suffix).
    script_name = _file_[:-3]

    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--config', default=None)
    parser.add_argument('--datefrom', default=None)
    parser.add_argument('--dateto', default=None)
    args = parser.parse_args()
    start = dt_to_ts(aware_now())

    # Console shows only CRITICAL unless --debug; the file gets everything.
    logger = logging.getLogger(script_name)
    formatter = logging.Formatter('%(asctime)s | %(message)s')
    ch = logging.StreamHandler()
    logger.setLevel(logging.DEBUG)
    ch.setLevel(logging.DEBUG if args.debug else logging.CRITICAL)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    fh = logging.FileHandler('{}/logs/{}-{}.log'.format(
        script_dir, script_name, start),
        encoding='utf8')
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # load config
    # basic.toml is read from the directory the script was started in;
    # the remaining files are read after changing to the script's directory.
    with open('basic.toml', 'r') as f:
        config = toml.loads(f.read())
    os.chdir(script_dir)
    with open('distribution.toml', 'r') as f:
        config.update(toml.loads(f.read()))
    if args.config is None:
        # The per-script config is optional: failing to load it is not fatal,
        # but the reason is logged instead of being silently discarded.
        try:
            with open(script_name + '.toml') as f:
                config.update(toml.loads(f.read()))
        except Exception:
            logger.debug('Optional config {}.toml not loaded:\n{}'.format(
                script_name, traceback.format_exc()))
    else:
        with open(args.config) as f:
            config.update(toml.loads(f.read()))

    defaults()
    MapReduce.useDefaults(server=config['mr_server'])
    alltables = get_alltables()
    logger.info('all tables loaded')
    if not (args.datefrom and args.dateto):
        # Reuse the listing fetched above for the first pass.
        srctables = get_srctables(lb=get_lastts(), alltables=alltables)
        # NOTE(review): if process_date keeps returning False the stored
        # timestamp never advances, so the same table is retried
        # indefinitely - confirm that this is intended.
        while len(srctables) > 0:
            t = process_date(srctables[0], logger, config)
            processed_ts = tstable(srctables[0])
            if t and processed_ts > get_lastts():
                set_lastts(processed_ts)
            srctables = get_srctables()
        # Report the timestamp as it stands now, not the value read at start.
        logger.info("No new data. Latest counted ts is {}"
                    .format(get_lastts()))
        sys.exit(0)
    else:
        lb = aware_strptime(args.datefrom, '%Y%m%d')
        ub = aware_strptime(args.dateto, '%Y%m%d')
        for srctable in get_srctables(lb=lb, ub=ub, alltables=alltables):
            process_date(srctable, logger, config)


def _run_until_success(logger, description, step):
    """Call ``step()`` until it finishes without raising.

    ``description`` is logged before every attempt and the traceback of each
    failed attempt is logged as an error. KeyboardInterrupt and SystemExit
    are not caught, so the loop can still be interrupted.
    NOTE(review): there is no delay or attempt limit between retries.
    """
    while True:
        try:
            logger.info(description)
            step()
        except Exception:
            logger.error(traceback.format_exc())
        else:
            return


def process_date(srctable, logger, config):
    """Process one source table and push its score quantiles to razladki.

    Steps, each skipped when its input table does not exist:
    fetch the logs into a temporary table, map them with ``atommap``,
    reduce with ``atomreduce``, then read the reduced table and push the
    0.1..0.9 quantiles of every (timestamp, client) score distribution.

    Returns True when the reduced table exists and was pushed, else False.
    """
    logger.info('Source table is {}'.format(srctable))
    ts = srctable.split('/')[-1]
    midtable1 = 'tmp/pers/atomlog{}'.format(ts)
    midtable2 = 'tmp/pers/atomlog{}map'.format(ts)
    dsttable = 'tmp/pers/atomlog{}reduce'.format(ts)

    if not table_exists(midtable1):
        env = os.environ.copy()
        env['MR_USER'] = 'tmp'
        logger.info(os.environ['PATH'])
        runstring = ('fetchlogs -s sakura00:8013 '
                     '-i {} '
                     '-o {}').format(srctable, midtable1)
        logger.info('Running `{}`'.format(runstring))
        returncode = subprocess.call(shlex.split(runstring), env=env)
        if returncode != 0:
            logger.error('fetchlogs exited with code {}'.format(returncode))

    if table_exists(midtable1):
        _run_until_success(
            logger,
            'Mapping from {} to {}'.format(midtable1, midtable2),
            lambda: MapReduce.runMap(atommap, srcTable=midtable1,
                                     dstTable=midtable2))

    if table_exists(midtable2):
        _run_until_success(
            logger,
            'Reducing from {} to {}'.format(midtable2, dsttable),
            lambda: MapReduce.runReduce(atomreduce, srcTable=midtable2,
                                        dstTable=dsttable))

    if not table_exists(dsttable):
        logger.critical('Something went wrong with {}.'.format(ts))
        return False

    # counters[timestamp][client][score] = number of occurrences.
    # Reduced keys have the form "<timestamp> <client>".
    counters = defaultdict(lambda: defaultdict(Counter))
    for rec in MapReduce.getSample(dsttable, count=None):
        rec1 = deutf8ify(rec)
        key_parts = rec1.key.split(' ')
        counters[int(key_parts[0])][key_parts[1]][
            float(rec1.subkey)] = int(rec1.value)

    quantiles = [0.1 * x for x in xrange(1, 10)]
    for timestamp, per_client in counters.items():
        for client, distribution in per_client.items():
            for quantile in quantiles:
                desc = 'atomlog_score_{}_{}'.format(client, quantile)
                value = counter_quantile(distribution, quantile)
                logger.info('Pushing to razladki: {}, {}'
                            .format(desc, value))
                push_to_razladki(config, desc, value, ts=timestamp)
    return True


# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
