# -*- coding: UTF-8 -*-

from nile.api.v1 import (clusters, Record, files)
from nile.api.v1 import datetime as nd
import copy
import sys
import json
import os
from getpass import getuser

# Login name of the user running this module; used in the usage example's
# pool name and in the default output directory built in Reqans_log_plainer.
username = getuser()
# Absolute path of this source file; Reqans_log_plainer.plain() attaches it to
# the map operation via files.LocalFile.
file_path = os.path.abspath(__file__)

class Reqans_log_plainer(object):
    """Pick reqans log input tables and run a map job that flattens them.

    Usage::

        cluster = clusters.Hahn(pool='search-research_{}'.format(username))
        parser = reqans_parse.Reqans_log_plainer(cluster)
        parser.init_table('latest 5min')
        table_path = parser.plain(client='distr_portal', sub_client='popup')

    ``init_table`` accepts one of the following ``date_range`` forms:

    * ``'latest 5min'`` or ``'latest 1d'``
    * ``'2017-03-27'`` or ``'2017-03-27..2017-03-28'``
    * ``'2017-03-30 10:00:00'`` or
      ``'2017-03-30 10:00:00..2017-03-30 15:05:00'``
    """

    def __init__(self, cluster, geobase_path='/home/zaringleb/geobase.json', atom=True):
        """Store the cluster, load the geobase and choose the output directory.

        Args:
            cluster: nile cluster object used to create jobs.
            geobase_path: local path of the geobase JSON file; loaded with
                ``get_geobase``, which falls back to an empty dict on failure.
            atom: when true, results go to ``//home/atom/<username>/``,
                otherwise to ``//tmp/``.
        """
        self.cluster = cluster
        self.geobase = get_geobase(geobase_path)
        self.path_to_save = "//home/atom/{}/".format(username) if atom else "//tmp/"

    def init_table(self, date_range):
        """Create a job and select the input table(s) described by *date_range*.

        Sets ``self.job`` and ``self.table``. For explicit dates ``self.cluster``
        is replaced by a copy whose ``date`` template holds the requested
        date(s).

        Args:
            date_range: see the class docstring for the accepted forms.

        Raises:
            ValueError: if *date_range* contains ``'latest'`` but neither
                ``'1d'`` nor ``'5min'``.
        """
        pref = "//logs/atomfront-reqans-log/"
        if 'latest' in date_range:
            if '1d' in date_range:
                suffix = "1d"
            elif '5min' in date_range:
                suffix = "stream/5min"
            else:
                # Fail before touching self.job / self.table: otherwise a new
                # job would be paired with a missing or stale table.
                raise ValueError(
                    'unsupported "latest" range {!r}; expected '
                    '"latest 1d" or "latest 5min"'.format(date_range)
                )
            self.job = self.cluster.job()
            self.table = self.job.cumulative_table(pref + suffix)
            return

        dates = date_range.split("..")
        if ':' in date_range:
            # A time component is present, so the 5-minute stream tables are used.
            if len(dates) == 1:
                date = dates[0].replace(" ", "T")
            else:
                date = nile_range(dates[0], dates[1])
            table_path = pref + "stream/5min/@date"
        else:
            # Date-only input, so the daily tables are used.
            if len(dates) == 1:
                date = dates[0]
            else:
                date = "{{{}..{}}}".format(dates[0], dates[1])
            table_path = pref + "1d/@date"

        self.cluster = self.cluster.env(templates=dict(date=date))
        self.job = self.cluster.job()
        self.table = self.job.table(table_path)

    def plain(self, client, sub_client):
        """Run the map job and return the path of the resulting table.

        ``init_table`` must have been called first.

        Args:
            client: value that ``row['client']`` must equal for a row to be kept.
            sub_client: value that an answer's ``name`` must equal to be emitted.

        Returns:
            ``self.path_to_save + "reqans_table"``.
        """
        mapper = Mapper(self.geobase, client, sub_client)
        output_path = self.path_to_save + "reqans_table"
        # This source file is attached to the operation via files.LocalFile.
        self.table = self.table \
            .map(mapper, files=[files.LocalFile(file_path)]) \
            .put(output_path)
        self.job.run()
        return output_path

def datetime_range(start_datetime, end_datetime, minutes=5):
    """Yield datetimes from *start_datetime* up to and including *end_datetime*.

    Both bounds are parsed with ``nd.Datetime.from_iso``. Each following value
    is produced by ``nd.next_datetime(..., scale='minutely', offset=minutes)``,
    presumably a step of *minutes* minutes (5 by default, matching the
    5-minute log tables).
    """
    current = nd.Datetime.from_iso(start_datetime)
    last = nd.Datetime.from_iso(end_datetime)

    while current <= last:
        yield current
        current = nd.next_datetime(current, scale='minutely', offset=minutes)

def nile_range(start_datetime, end_datetime):
    """Build a ``{name1,name2,...}`` template listing the 5-minute table names.

    Every datetime produced by ``datetime_range`` is formatted with
    ``'%FT%T'`` and the results are joined with commas inside curly braces.
    """
    table_names = [
        moment.strftime('%FT%T')
        for moment in datetime_range(start_datetime, end_datetime)
    ]
    joined = ','.join(table_names)
    return '{{{}}}'.format(joined)

def get_params_from_request_url(request_url):
    """Parse the query string of *request_url* into a ``{name: value}`` dict.

    Deprecated.

    Everything after the first ``?`` is split on ``&``; each item is split on
    its first ``=``. Values are not URL-decoded. If a name is repeated, the
    last value wins.

    A URL without ``?`` or with an empty query yields ``{}``; an item without
    ``=`` is stored with an empty-string value; empty items (e.g. from
    ``a=1&&b=2``) are skipped.
    """
    # deprecated
    _, _, query = request_url.partition("?")
    params = {}
    for item in query.split("&"):
        if not item:
            continue
        name, _, value = item.partition("=")
        params[name] = value
    return params

def bite(s):
    """Return the first whitespace-separated token of *s*, lower-cased.

    Falsy input (``None``, ``''``) is returned unchanged. A string that holds
    only whitespace has no tokens and yields ``''`` instead of raising
    ``IndexError``.
    """
    if not s:
        return s
    tokens = s.split()
    if not tokens:
        return ''
    return tokens[0].lower()

def get_country(x, jgb):
    """Return the normalised ``iso_name`` for region id *x* from geobase *jgb*.

    If the region's ``type`` is 3 or lower, its own ``iso_name`` is used.
    Otherwise the ids in the region's ``path`` are scanned in order and the
    ``iso_name`` of the first one whose ``type`` equals 3 is used (presumably
    type 3 is the country level -- confirm against the geobase schema).
    The name is passed through ``bite``. Returns ``None`` when no such entry
    is found; raises ``KeyError`` for ids missing from *jgb*.
    """
    region = jgb[x]
    if region['type'] <= 3:
        return bite(region['iso_name'])
    for ancestor_id in region['path']:
        ancestor = jgb[ancestor_id]
        if ancestor['type'] == 3:
            return bite(ancestor['iso_name'])
    return None

def get_lr(referer):
    """Return the first ``lr`` query parameter of the *referer* URL.

    Returns ``''`` when *referer* is empty/``None`` or has no ``lr`` parameter.
    """
    # Imported locally because this module has no top-level urlparse import,
    # and the module is named differently on Python 2 and Python 3.
    try:
        from urlparse import urlparse, parse_qs
    except ImportError:
        from urllib.parse import urlparse, parse_qs

    if not referer:
        return ''
    query = parse_qs(urlparse(referer).query)
    if 'lr' in query:
        return query['lr'][0]
    return ''

def preprocess_geobase(z):
    """Return a deep copy of geobase entry *z* with numeric fields converted.

    ``id`` and ``type`` become ints; ``path`` (a ``', '``-separated string of
    ids) becomes a list of ints, with empty pieces dropped. The input mapping
    is left untouched.
    """
    entry = copy.deepcopy(z)
    entry['id'] = int(entry['id'])
    entry['type'] = int(entry['type'])
    entry['path'] = [int(piece) for piece in entry['path'].split(', ') if piece]
    return entry

def get_geobase(path):
    """Load the geobase JSON file at *path* into a ``{region_id: entry}`` dict.

    The file must contain an iterable of entries; each is normalised with
    ``preprocess_geobase`` and keyed by its integer ``id``.

    Loading is best-effort: on any error (missing file, invalid JSON,
    malformed entry) a warning is written to stderr and ``{}`` is returned.
    """
    try:
        # The context manager closes the file even when parsing fails.
        with open(path) as geobase_file:
            gb = json.load(geobase_file)
        geobase = {int(x['id']): preprocess_geobase(x) for x in gb}
    except Exception:
        sys.stderr.write('WARNING: unable get geobase from {}'.format(path))
        geobase = {}
    return geobase

class Mapper(object):
    """Mapper that flattens reqans log rows for one client / sub-client.

    For every input row whose ``client`` equals ``self.client``, one ``Record``
    is emitted per entry of ``row['rest']['answers']`` whose ``name`` equals
    ``self.sub_client``. A row that raises while being processed is reported
    to stderr and skipped, so one malformed row does not abort the operation.
    """

    def __init__(self, geobase, client, sub_client):
        """Remember the geobase dict and the client / sub-client filters."""
        self.geobase = geobase
        self.client = client
        self.sub_client = sub_client

    def __call__(self, rows):
        """Yield a flat ``Record`` for each matching answer in *rows*."""
        for row in rows:
            try:
                if row['client'] != self.client:
                    continue

                # The country lookup is best-effort: a failure is logged and
                # the country is left empty.
                try:
                    country = get_country(int(row['region']), self.geobase)
                except Exception as e:
                    sys.stderr.write('GEO_ERROR: {}'.format(str(e)))
                    country = ''

                # First User-Agent request header, or '' when there is none.
                user_agent = ([one['value'] for one in row["requestHeaders"] if one["name"] == "User-Agent"] or [""])[0]

                props_json = json.loads(row['rest'].get('propsJson', '{}'))
                ua_traits = props_json.get('ua-traits', {})
                browser = ua_traits.get('BrowserName', '')
                # Named os_family so the local does not shadow the os module.
                os_family = ua_traits.get('OSFamily', '')
                # list() keeps this a plain list on Python 3 as well.
                instp = list(props_json.get('product-profile', {}).get("instp", {}).keys())

                for answer in row["rest"].get('answers', []):
                    if answer['name'] != self.sub_client:
                        continue

                    aux_info_props = answer["auxInfo"]["propsJson"]
                    flt = json.loads(aux_info_props)["filtering-results"]

                    # Keep only the needed doc fields and split "host/product"
                    # links. A link without exactly one "/" raises, and the
                    # whole row is then skipped by the outer handler.
                    docs = answer.get('docs', [])
                    light_docs = [{one: doc[one] for one in ["link", "score", "bannerId"]} for doc in docs]
                    for doc in light_docs:
                        host, product = doc["link"].split("/")
                        doc['host'] = host
                        doc['product'] = product

                    yield Record(
                        client=row['client'],
                        sub_client=answer['name'],
                        flt=flt,
                        user_agent=user_agent,
                        reqid=row['requestId'],
                        timestamp=row['timestamp'],
                        lang=row['lang'],
                        exp=",".join([one['testId'] for one in row["experiments"]]),
                        lr=row['region'],
                        country=country,
                        docs=light_docs,
                        uuid=row.get('uuid', ''),
                        yuid=row.get('yandexUid', ''),
                        # "* 1" turns a boolean flag into 0/1.
                        training=row['rest']['collectPoolMode'] * 1,
                        browser=browser,
                        os=os_family,
                        requestUrl=row['rest']["requestUrl"],
                        instp=instp,
                        aux_info_props=aux_info_props,
                        collectPoolMode=row['rest']['collectPoolMode']
                    )
            except Exception as e:
                sys.stderr.write('ERROR: {}'.format(str(e)))

def plot_data(df, prop, top=None, share=False, h_window=1, file_to_save=None):
    """Plot the ``count`` column of *df* over ``time``, one line per *prop* value.

    Args:
        df: DataFrame with ``time``, ``count`` and *prop* columns.
        prop: name of the column whose distinct values become separate lines.
        top: if truthy, draw only the *top* values with the largest total count.
        share: if true, plot each value's share of the per-time total instead
            of the raw count.
        h_window: window size of the rolling mean applied to each line.
        file_to_save: if truthy, the figure is saved to this path.
    """
    import pandas as pd
    import matplotlib.pyplot as plt
    try:
        # Optional; sns is not referenced below, so presumably it is imported
        # only for its side effects on plot styling.
        import seaborn as sns
    except Exception:
        pass

    counts_by_time = df.groupby(['time', prop], as_index=False).agg({'count': sum})

    # Distinct prop values ordered by total count, largest first.
    totals_by_value = df.groupby([prop], as_index=False).agg({'count': sum})
    ranked = totals_by_value.sort_values('count', ascending=False)
    lines = ranked[prop].values
    if top:
        lines = lines[:top]

    f, ax = plt.subplots(1, 1, sharey=True, sharex=True, figsize=(15, 5))
    handles = []

    if share:
        # Divide every count by the total count of its time point.
        totals_by_time = counts_by_time.groupby('time', as_index=False).agg({'count': sum})
        totals_by_time = totals_by_time.rename(columns={'count': 'count_total'})
        counts_by_time = pd.merge(counts_by_time, totals_by_time, on='time')
        counts_by_time['count'] = counts_by_time['count']*1./counts_by_time['count_total']

    for value in lines:
        subset = counts_by_time[counts_by_time[prop] == value]
        smoothed = subset['count'].rolling(window=h_window, center=False).mean()
        handle, = ax.plot(subset['time'], smoothed, label=value, lw=2)
        handles.append(handle)

    ax.legend(handles=handles, fontsize=15, markerscale=5, loc=2, bbox_to_anchor=(1,1), borderaxespad=2)
    if file_to_save:
        f.savefig(file_to_save)
