#!/usr/bin/env python
#! -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division
import sys
import os
import codecs
import argparse
import logging
import toml
import re
import pdb
import pickle
import traceback
import urlparse
from collections import defaultdict
from mapreducelib import MapReduce, Record
import datetime as dt
from datetime import datetime as dtdt
from pecheny.mrdef import defaults
from pecheny.commons import table_exists
from pecheny.moncommons import push_to_razladki
import itertools

# Make __file__ absolute (Python 2 may leave it relative) and keep the
# bare script name; `_file_[:-3]` is used below for logger/config names.
__file__ = os.path.abspath(__file__)
_file_ = os.path.split(__file__)[1]

# Not referenced anywhere else in the visible part of this file.
PATH = ['12.1620.705']

# Product names whose installs are tracked (passed to FirstMap and used
# by determine_first_visit).
names = set([
    'altsearchchrome',
    'homesearchextchrome',
    'searchextchrome',
    'startextchrome',
    'vbch',
])

def info(type, value, tb):
    """
    Exception hook for debugging (borrowed recipe).

    When running interactively, or when stderr is not a tty, the
    exception is handed to the default hook. Otherwise the traceback
    is printed and pdb is started in post-mortem mode.
    """
    interactive = hasattr(sys, 'ps1')
    if interactive or not sys.stderr.isatty():
        # interactive session or no tty-like device: default behaviour
        sys.__excepthook__(type, value, tb)
        return

    import traceback, pdb
    # non-interactive run attached to a terminal: show the exception...
    traceback.print_exception(type, value, tb)
    print
    # ...and drop into the debugger at the point of failure.
    pdb.pm()

# Install the hook for every uncaught exception in this process.
sys.excepthook = info

def determine_first_visit(timestamps, clids):
    """
    Returns (timestamp, clid) for the first clid that is one of the
    tracked product names; (0, 'no') when there is no such clid.

    timestamps and clids are parallel sequences.
    """
    for position, candidate in enumerate(clids):
        if candidate in names:
            return timestamps[position], candidate
    return 0, 'no'

def deutf8ify(rec):
    """
    Builds a new Record whose key, subkey and value are unicode.

    Fields that are not unicode yet are decoded as UTF-8; undecodable
    bytes are replaced rather than raising.
    """
    def as_unicode(field):
        if isinstance(field, unicode):
            return field
        return field.decode('utf8', errors='replace')

    return Record(
        as_unicode(rec.key),
        as_unicode(rec.subkey),
        as_unicode(rec.value))

def utf8ify(rec):
    """
    Encodes unicode key/subkey/value of rec to UTF-8 byte strings.

    The record is modified in place and returned; fields that are not
    unicode are left untouched.
    """
    for field in ('key', 'subkey', 'value'):
        current = getattr(rec, field)
        if isinstance(current, unicode):
            setattr(rec, field, current.encode('utf8'))
    return rec

def normalize_clid(clid):
    """
    Reduces a clid to its digits.

    Suffixes of the form -001 (a dash followed by digits) are dropped
    first, then every remaining non-digit character is removed.
    """
    without_suffixes = re.sub(r'\-[0-9]+', '', clid)
    return re.sub(r'[^0-9]', '', without_suffixes)

def ymd(date):
    """
    Formats a date as YYYYMMDD (no separators).
    """
    compact_format = '%Y%m%d'
    return date.strftime(compact_format)

def y_m_d(date):
    """
    Formats a date as YYYY-MM-DD.
    """
    dashed_format = '%Y-%m-%d'
    return date.strftime(dashed_format)

def avg(list_):
    """
    Returns the arithmetic mean of list_ as a float.

    Best-effort: returns 0 when the mean cannot be computed (empty
    input, non-numeric items, objects without len()).
    Importing numpy for just np.mean was uncool.
    """
    try:
        return sum(list_) / float(len(list_))
    except Exception:
        # Deliberately lenient, but unlike a bare except this lets
        # KeyboardInterrupt / SystemExit through.
        return 0

def avg_norm(list_):
    """
    Given a list, returns its normalized average.
    Normalization occurs by throwing away top & bottom 5%
    (rounded down to a whole number of items on each side).

    Returns 0 for empty input, like avg().
    """
    ordered = sorted(list_)
    margin = int(len(ordered) * 0.05)
    # An explicit upper bound is required: with margin == 0 the slice
    # [margin:-margin] is [0:0], which dropped every element of lists
    # shorter than 20 items.
    return avg(ordered[margin:len(ordered) - margin])

def getvalue(string, val):
    """
    Looks up the field `val=...` in a tab-separated string.

    Returns the text after `val=`; if the field occurs several times
    the last occurrence wins, and "" is returned when it is absent.
    """
    prefix = val + "="
    found = [
        field[len(prefix):]
        for field in string.split("\t")
        if field.startswith(prefix)
    ]
    if found:
        return found[-1]
    return ""

def make_reqrelev(rr):
    """
    Turns a `k1=v1;k2=v2;...` string into a dict.

    Items without '=' are skipped; only the first '=' of an item
    separates key from value, so values may contain '='.
    """
    parsed = {}
    for item in rr.split(';'):
        name, separator, rest = item.partition('=')
        if separator:
            parsed[name] = rest
    return parsed

class Querypropsstat(object):
    """
    Map operation that keeps www.yandex REQUEST events of the uids in
    self.uids (input uids may be given as y[0-9]... or [0-9]...).

    For every matching record it emits
        key = uid, subkey = subkey of the source record,
        value = clid \t is_nav=<...> \t fuid=<...>
    """

    def __init__(self, uids, strict=False):
        """
        With strict=True the uids are stored as given. Otherwise uids
        of length <= 5 are dropped and uids starting with a digit get
        a 'y' prefix.
        """
        if strict:
            self.uids = uids
            return
        digits = set('0123456789')
        prepared = set()
        for x in uids:
            if len(x) > 5:
                prepared.add('y{}'.format(x) if x[0] in digits else x)
        self.uids = prepared

    def __call__(self, rec):
        rec = deutf8ify(rec)
        uid = rec.key
        if uid not in self.uids:
            return
        payload = rec.value
        markers = ("service=www.yandex", "ui=www.yandex", "type=REQUEST")
        if not all(marker in payload for marker in markers):
            return

        time = rec.subkey
        full_request = getvalue(payload, "full-request")
        fuid = getvalue(payload, "fuid")

        # The last clid= parameter of the request wins; a missing or
        # empty one is reported as 'no'.
        clid = "no"
        pieces = full_request.replace("&", "|").replace("?", "|").split("|")
        for piece in pieces:
            if piece.startswith("clid="):
                clid = piece[len("clid="):]
        if not clid:
            clid = 'no'

        isnav = make_reqrelev(getvalue(payload, "reqrelev")).get('is_nav', '')
        value = '{}\tis_nav={}\tfuid={}'.format(clid, isnav, fuid)
        yield utf8ify(Record(uid, time, value))

class FirstMap(object):
    """
    Map operation producing the installs table:

    key = key of the source record (yandexuid)
    subkey = subkey of the source record (ts of install)
    value = productname \t user-region

    Only TECH records with dayuse=0 whose productname is in
    self.names are emitted.
    """

    def __init__(self, names):
        self.names = names

    def parseparams(self, value):
        """
        Parses a tab-separated `key=value` string. Fields without '='
        map to 'SINGLE'; missing keys read as ''.
        """
        parsed = {}
        for field in value.split('\t'):
            name, separator, rest = field.partition('=')
            parsed[name] = rest if separator else 'SINGLE'
        return defaultdict(lambda: '', parsed)

    def parsevars(self, vars):
        """
        Parses a comma-separated `key=value` string. A leading '-' is
        stripped from keys; values of keys starting with 'clid' are
        collected in the list under 'clids'. Items without '=' map to
        'SINGLE'; missing keys read as ''.
        """
        parsed = {'clids': []}
        for item in vars.split(','):
            name, separator, rest = item.partition('=')
            if not separator:
                parsed[item] = 'SINGLE'
                continue
            if name.startswith('-'):
                name = name[1:]
            if name.startswith('clid'):
                parsed['clids'].append(rest)
            else:
                parsed[name] = rest
        return defaultdict(lambda: '', parsed)

    def __call__(self, rec):
        from collections import defaultdict
        rec = deutf8ify(rec)
        params = self.parseparams(rec.value)
        if params['type'] != 'TECH' or 'vars' not in params:
            return
        vars = self.parsevars(params['vars'])
        if vars['dayuse'] != '0' or vars['productname'] not in self.names:
            return
        value = vars['productname'] + '\t' + params['user-region']
        yield utf8ify(Record(rec.key, rec.subkey, value))

def retry(maxretries):
    """
    Decorator factory: retries the wrapped function when it raises.

    The function is attempted at most `maxretries` times. Every failure
    is logged with its traceback; if the last attempt fails as well,
    its exception is re-raised to the caller. The first successful
    result is logged and returned.

    The wrapper exposes the positional-argument count of the wrapped
    function as the attribute `ac`.
    """
    logger = logging.getLogger(_file_[:-3])
    def retry_decorator(func):
        def retry_wrapper(*args, **kwargs):
            logger.info('Trying {} with arguments: {}, {}.'
                .format(func.__name__, args, kwargs))
            for attempt in range(1, maxretries + 1):
                try:
                    reply = func(*args, **kwargs)
                except Exception:
                    # Exception (not a bare except) so that Ctrl-C and
                    # SystemExit are not swallowed and retried.
                    logger.info(traceback.format_exc())
                    if attempt == maxretries:
                        # out of attempts: surface the failure instead
                        # of silently returning None
                        raise
                    continue
                # Logging happens outside the try block: a reply that
                # cannot be formatted must not trigger a re-run of func.
                if reply:
                    try:
                        shown = reply.replace('\n', '\\n')
                    except AttributeError:
                        shown = repr(reply)
                    logger.info('Reply: {}'.format(shown))
                else:
                    logger.info('Reply: none')
                return reply
        # __code__ is available on Python 2.6+ and Python 3
        # (func_code is Python 2 only).
        retry_wrapper.ac = func.__code__.co_argcount
        return retry_wrapper
    return retry_decorator

@retry(5)
def run_map(op, src, dst, append=False):
    """
    Thin convenience wrapper around MapReduce.runMap: runs `op` over
    table `src` writing to table `dst` (appending when `append` is
    true). Wrapped in retry(5) to ride out transient failures.
    """
    MapReduce.runMap(op, srcTable=src, dstTable=dst, appendMode=append)

def make_test_record(line):
    """
    Makes a Record out of a (tab-separated) line:
    first field -> key, second -> subkey, the rest -> value.

    Trailing whitespace is stripped first. Returns None when the line
    has fewer than two fields (e.g. an empty line), so callers can skip
    it instead of hitting an IndexError.
    """
    tabs = line.rstrip().split('\t')
    if len(tabs) < 2:
        return None
    return Record(tabs[0], tabs[1], '\t'.join(tabs[2:]))

def file_to_reclist(filename):
    """
    A function for testing/debugging purposes.
    Reads a UTF-8 file, makes a record out of each non-blank line
    (see make_test_record) and returns the list of these records.
    """
    # `with` closes the file even if reading fails.
    with codecs.open(filename, 'r', 'utf8') as source:
        lines = source.read().split('\n')
    reclist = []
    for line in lines:
        # split('\n') leaves an empty last item for files ending in a
        # newline; blank lines carry no record.
        if not line.strip():
            continue
        rec = make_test_record(line)
        if rec:
            reclist.append(rec)
    return reclist

def make_uids(reclist):
    """
    Groups install records by uid.

    Returns a plain dict:
        key = rec.key (uid)
        value = list of (productname, ts of install), one per record,
                in input order
    where productname is the first tab-separated field of rec.value
    and ts is int(rec.subkey). Records with an empty productname are
    skipped.
    """
    installs = defaultdict(list)
    for rec in reclist:
        productname = rec.value.split('\t')[0]
        ts = int(rec.subkey)
        if not productname:
            continue
        installs[rec.key].append((productname, ts))
    return dict(installs)

def make_data(recs):
    """
    Returns dict of form:

    key = uid
    value = list of [ts, clid] for each query of uid

    key/subkey/value of every record are decoded from UTF-8
    (undecodable bytes are replaced); ts is int(subkey) and clid is the
    first tab-separated field of value ('empty' when blank).

    Only uids with more than 5 queries are kept.

    Assumes recs are sorted, will _not_ work on unsorted data.
    """
    data = {}
    previous = ''
    stack = []
    for rec in recs:
        key, subkey, value = (
            rec.key.decode('utf8', errors='replace'),
            rec.subkey.decode('utf8', errors='replace'),
            rec.value.decode('utf8', errors='replace'),
        )
        ts = int(subkey)
        if key != previous:
            # a new uid starts: store the finished group if big enough
            if len(stack) > 5:
                data[previous] = stack
            stack = []
            previous = key
        clid = value.split('\t')[0] or 'empty'
        stack.append([ts, clid])
    # The last group must pass the same size filter as all the others
    # (previously it was stored unconditionally, and empty input
    # produced a bogus '' entry).
    if len(stack) > 5:
        data[previous] = stack
    return data

def make_final_data(data, uids):
    """
    Splits every uid's queries around the moment of install.

    Returns a plain dict:

    key = productname
    value = list of tuples
        (
        uid,
        list of (ts, clid) for queries before install,
        list of (ts, clid) for queries after install,
        ts of install
        )

    The product and install ts are taken from the first entry of
    uids[uid]; the birth ts is the last 10 characters of the uid.
    Uids missing from `uids` are skipped, as are uids whose install
    happened before birth or less than (a week minus one second) after
    it. Queries with clid 'logged in' and queries made exactly at the
    install ts end up in neither list.
    """
    min_age = 86400 * 7 - 1
    by_product = defaultdict(list)
    for uid in data:
        visits = sorted(data[uid], key=lambda visit: visit[0])
        try:
            product, ts_of_install = uids[uid][0]
        except KeyError:
            continue
        ts_of_birth = int(uid[-10:])
        too_young = (ts_of_install < ts_of_birth
            or ts_of_install - ts_of_birth < min_age)
        if too_young:
            continue
        before = []
        after = []
        for visit in visits:
            ts, clid = visit[0], visit[1]
            if clid == 'logged in':
                continue
            if ts < ts_of_install:
                before.append((ts, clid))
            elif ts > ts_of_install:
                after.append((ts, clid))
        by_product[product].append((uid, before, after, ts_of_install))
    return dict(by_product)

def count_product_stats(cstats):
    """
    Summarises one value of final_data (a list of
    (uid, before, after, ts of install) tuples) as a list:
    [
        number of uids,
        list of query counts before install (one per uid),
        list of query counts after install (one per uid),
        list of per-uid differences (after - before)
    ]
    """
    uids = len(cstats)
    befores = []
    afters = []
    for entry in cstats:
        befores.append(len(entry[1]))
        afters.append(len(entry[2]))
    diffs = [after - before for before, after in zip(befores, afters)]
    return [uids, befores, afters, diffs]

def ttest_pvalue(a1, a2):
    """
    Returns the two-sided p-value of Welch's t-test (unequal
    variances) for the samples a1 and a2.

    Returns 1 ("no evidence of a difference") when the test is
    degenerate: scipy/numpy emitted a RuntimeWarning or the p-value
    came out as NaN.
    """
    # `stats` was never imported at file level, so the import lives
    # here to keep the function self-contained.
    import math
    import warnings
    from scipy import stats

    with warnings.catch_warnings():
        # Warnings are not raised as exceptions by default; escalate
        # RuntimeWarning so the handler below can actually fire.
        warnings.simplefilter('error', RuntimeWarning)
        try:
            pvalue = stats.ttest_ind(a1, a2, equal_var=False)[1]
        except RuntimeWarning:
            return 1
    if math.isnan(pvalue):
        return 1
    return pvalue

def tryparseyyyymmdd(string):
    """
    Parses a YYYYMMDD string into a datetime.

    Returns None when the input is not a valid date in that format
    (including non-string input such as None).
    """
    try:
        return dtdt.strptime(string, '%Y%m%d')
    except (ValueError, TypeError):
        # ValueError: wrong format / impossible date;
        # TypeError: input is not a string.
        return None

def _load_processed_dates(filename):
    """
    Reads the bookkeeping file (one YYYYMMDD date per line) and returns
    the set of dates it contains; lines that do not parse are ignored.
    """
    with open(filename) as f:
        raw = f.read().decode('utf8', errors='replace')
    parsed = set(tryparseyyyymmdd(x) for x in raw.split('\n'))
    parsed.discard(None)
    return parsed


def main():
    """
    Command-line entry point.

    Sets up logging (console + a file under logs/ next to the script),
    loads the TOML configs, works out which dates to process and calls
    process_date for each of them in chronological order.

    With --date only that date is processed. Without it, every date
    from 2015-08-10 up to (but excluding) today - 8 days that is not
    listed in the processed-dates file is processed.
    """
    global _file_
    global __file__                         # to fix stupid
    __file__ = os.path.abspath(__file__)    # __file__ handling
    _file_ = os.path.basename(__file__)     # in python 2

    # NOTE(review): these local imports are not used directly in this
    # function; kept as-is in case they are relied on for side effects.
    import requests
    from pecheny.moncommons import push_to_razladki
    import mapreducelib
    from mapreducelib import MapReduce, Record

    parser = argparse.ArgumentParser()
    parser.add_argument('--debug','-d',action='store_true')
    parser.add_argument('--date','-date',default=None, 
        help='Default date is (today - 8).')
    parser.add_argument('--config','-r',default=None,
        help='Default config file is %filename-without-extension%.toml')
    args = parser.parse_args()

    start = dtdt.now()

    # set up logging: everything goes to the log file, the console only
    # shows messages when --debug is given
    logger = logging.getLogger(_file_[:-3])
    formatter = logging.Formatter('%(asctime)s | %(message)s')
    ch = logging.StreamHandler()
    logger.setLevel(logging.DEBUG)
    if args.debug:
        ch.setLevel(logging.DEBUG)
    else:
        ch.setLevel(logging.CRITICAL)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    fh = logging.FileHandler('{}/logs/{}-{}.log'.format(
        os.path.dirname(__file__),_file_[:-3], start),
        encoding='utf8')
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # load config; later files override earlier ones.
    # basic.toml is read relative to the current directory, everything
    # after the chdir relative to the script's directory.
    with open('basic.toml','r') as f:
        config = toml.loads(f.read())
    os.chdir(os.path.dirname(__file__))
    with open('distribution.toml','r') as f:
        config.update(toml.loads(f.read()))
    if args.config is None:
        config_name = _file_[:-3]+'.toml'
    else:
        config_name = args.config
    with open(config_name) as f:
        config.update(toml.loads(f.read()))

    # process_date needs the set of processed dates in both modes (it
    # adds the date and rewrites the file); previously the name was
    # undefined whenever --date was given.
    processed_dates = _load_processed_dates(
        'distribution_installs_monitoring_weekprofit_new_dates')

    if args.date is None:
        initialdate = dtdt(2015, 8, 10)
        i = initialdate
        dates = set()
        while i < (dtdt.today() - dt.timedelta(days=8)):
            if i not in processed_dates:
                dates.add(i)
            i += dt.timedelta(days=1)
    else:
        dates = [dtdt.strptime(args.date.replace('-',''), '%Y%m%d')]

    defaults()
    config['debug'] = args.debug
    for date in sorted(dates):
        process_date(date, processed_dates, config)

def _dump_pickle(obj, filename):
    """
    Pickles obj into filename, closing the file afterwards.
    """
    with open(filename, 'wb') as sink:
        pickle.dump(obj, sink)


def process_date(date, processed_dates, config):
    """
    Computes the per-product "average profit" for installs of `date`.

    Steps:
    1. build (unless it already has data) the installs table for
       `date` from user_sessions/<date>;
    2. build (unless it exists) the queries table for the installing
       uids over the window date - 7 days .. date + 7 days;
    3. for every product log the normalized average of
       (queries after install - queries before install);
    4. add `date` to `processed_dates` and rewrite the bookkeeping
       file.

    Intermediate results are pickled to uids_<ts>.pkl, data_<ts>.pkl
    and final_data_<ts>.pkl in the current directory, where ts is the
    epoch timestamp of `date`.

    Returns early (without marking the date as processed) when a
    source table is missing or no installs were found.
    `config` is currently unused: pushing the values to razladki is
    commented out, they are only logged.
    """
    from mapreducelib import MapReduce, Record
    date_from = date - dt.timedelta(days=7)
    date_to = date + dt.timedelta(days=7)
    ts = int((date - dtdt(1970, 1, 1)).total_seconds())

    # every day of the window, both ends included
    dates = []
    while date_from <= date_to:
        dates.append(date_from)
        date_from += dt.timedelta(days=1)

    logger = logging.getLogger(_file_[:-3])
    srctable = 'user_sessions/{}'.format(ymd(date))
    insttable = 'tmp/pers/avgprofit_{}_installs'.format(ymd(date))
    # the installs table is (re)built only when no record can be
    # sampled from it
    if len(list(MapReduce.getSample(insttable, count=1))) != 1:
        available = True
        for d in dates:
            sessions_check = 'user_sessions/{}'.format(ymd(d))
            logger.info('Checking table {}'.format(sessions_check))
            if not table_exists(sessions_check):
                # no early exit: every missing table gets logged
                logger.info('Table {} does not exist'
                    .format(sessions_check))
                available = False
        if not available:
            logger.info('Cannot analyze due to '
                'inavailability of one or more tables :(')
            return

        firstmap = FirstMap(names)
        logger.info('Mapping from {} to {}'
            .format(srctable, insttable))
        run_map(firstmap, srctable, insttable)

    # populate uids
    logger.info('Getting uids from {}'
        .format(insttable))
    reclist = MapReduce.getSample(insttable, count=None)
    logger.info('Finished getting uids from {}'
        .format(insttable))
    uids = make_uids(reclist)

    if not uids:
        logger.info('No uids (probably no sessions)')
        return
    _dump_pickle(dict(uids), 'uids_{}.pkl'.format(ts))

    count_requests = Querypropsstat(set(uids.keys()))
    dsttable = 'tmp/pers/avgprofit_{}_queries'.format(ymd(date))

    if not table_exists(dsttable):
        logger.info('{} does not exist already, mapping...'
            .format(dsttable))
        for date_ in dates:
            srctable = 'user_sessions/{}'.format(ymd(date_))
            run_map(count_requests, srctable, dsttable,
                    append=True)
    else:
        logger.info('{} already exists'
            .format(dsttable))

    # make_data relies on records of one uid being adjacent
    logger.info('Sorting {}'
                    .format(dsttable))
    MapReduce.sortTable(dsttable)

    logger.info('Getting uids from {}'.format(dsttable))
    reclist = MapReduce.getSample(dsttable, count=None)
    logger.info('Making data...')
    data = make_data(reclist)

    _dump_pickle(dict(data), 'data_{}.pkl'.format(ts))
    logger.info('Counting stats..')

    final_data = make_final_data(data, uids)

    logger.info('Pickling data to final_data_{}.pkl'.format(ts))
    _dump_pickle(dict(final_data), 'final_data_{}.pkl'.format(ts))
    for group in final_data:
        product_stats = count_product_stats(final_data[group])
        # last item of the stats = per-uid (after - before) diffs
        avgdiff = avg_norm(product_stats[-1])

        desc = '{}_avgprofit_weekago'.format(group)
        value = avgdiff
        # NOTE: pushing to razladki is disabled, the value is only logged.
        # push_to_razladki(config, 
                         # desc, 
                         # value,
                         # ts=ts)
        logger.info('{}\t{}\t{}\t{}'
                .format(desc, value, ts, dtdt.fromtimestamp(ts)))
    processed_dates.add(date)
    # `with` makes sure the bookkeeping file is flushed and closed
    with open('distribution_installs_monitoring_weekprofit_new_dates',
              'w') as f:
        f.write('\n'.join(sorted(ymd(x) for x in processed_dates)))

if __name__ == "__main__":
    main()
