# -*- coding: utf-8 -*-

import json
import datetime
import os
import time
import logging
from collections import Counter
import argparse
import pandas as pd
from nile.api.v1 import (filters as nf, aggregators as na, extractors as ne, clusters, Record)
import nile
from yasmapi import GolovanRequest

# Columns identifying one curve; used as groupby/merge keys throughout the file.
PROPS = ['client', 'sub_client']
# Event-count columns that are summed per score and then accumulated into curves.
EVENTTYPES = ["show", "close", "click", "trueinstall"]
# Directory prefix for the log file and the list of already processed dates.
PATH = "//home/zaringleb/curve_pic_storage/"


def plot_pic_b(df, picname, name, logger, hist=False, rate_dict=None):
    """Plot a grid of normalised event curves and save it as PNG files.

    The grid has one row per distinct ``client`` and one column per
    ``sub_client`` of that client (clients with fewer sub-clients leave the
    trailing cells of their row empty).  In every cell the ``click``,
    ``close`` and ``trueinstall`` columns are drawn against ``show``; each
    series is divided by its own maximum so all axes span [0, 1].

    Args:
        df: DataFrame with columns ``client``, ``sub_client``, ``show``,
            ``click``, ``close`` and ``trueinstall`` (the caller in this
            file passes the output of ``prepare_df``).
        picname: Path the figure is saved to.
        name: Figure title.
        logger: Logger used for progress and diagnostic messages.
        hist: When true, a vertical grey line is drawn at the rate found in
            ``rate_dict`` and the extra copy is saved as ``training.png``;
            otherwise the ``Agg`` backend is selected and the extra copy is
            saved as ``prod.png``.
        rate_dict: Mapping ``(client, sub_client) -> {'rate': float, ...}``;
            only read when ``hist`` is true.  Missing entries fall back to 0.

    Returns:
        None.  If ``df`` contains no clients nothing is drawn or saved.
    """
    import matplotlib
    if not hist:
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    # Imported (unused) by the original code as well; kept in case the import
    # itself is relied upon for plot styling.
    import seaborn as sns  # noqa: F401

    # Work out the grid: rows are clients, columns are their sub-clients.
    clients = sorted(df["client"].unique())
    if not clients:
        # max() over an empty sequence below would raise; there is nothing to draw.
        logger.info("no data to plot for {}".format(name))
        return
    grid = [
        (client, sorted(df.loc[df["client"] == client, "sub_client"].unique()))
        for client in clients
    ]
    ny = len(grid)
    nx = max(len(sub_clients) for _, sub_clients in grid)

    # squeeze=False keeps ``axs`` two-dimensional even for a single row or
    # column, so ``axs[j][i]`` is always valid.  The float literal avoids
    # integer division of the figure height under Python 2.
    f, axs = plt.subplots(ny, nx, sharey=True, sharex=True,
                          figsize=(15, 15.0 * ny / nx), squeeze=False)
    try:
        plt.suptitle(name, size=20)
        logger.info("fasegrid scale: {}x{}".format(nx, ny))

        curves = (("click", "blue"), ("close", "red"), ("trueinstall", "green"))
        for j, (client, sub_clients) in enumerate(grid):
            pp = (df["client"] == client)
            for i, sub_client in enumerate(sub_clients):
                ax = axs[j][i]
                ax.set_title(sub_client, fontsize=15)
                if i == 0:
                    # Label each row once, on its left-most cell.
                    ax.set_ylabel(client, fontsize=16)
                t = df[pp & (df["sub_client"] == sub_client)]
                logger.debug("client={}, sub_client={}, len(t)={}".format(client, sub_client, len(t)))

                x = t["show"] / t["show"].max()
                for event, color in curves:
                    ax.plot(x, t[event] / t[event].max(), lw=2, color=color, label=event)

                if hist:
                    try:
                        rate = rate_dict[(client, sub_client)]['rate']
                    except (KeyError, TypeError):
                        # KeyError: pair or 'rate' missing; TypeError: rate_dict is None.
                        logger.info("client={}, sub_client={}: unknown rate".format(client, sub_client))
                        rate = 0
                    ax.plot([rate, rate], [0, 1], color="grey", lw=4)

                ax.set_xlim([0, 1])
                ax.set_ylim([0, 1])
                ax.set_xticklabels([])
                ax.set_yticklabels([])

        f.savefig(picname)
        if hist:
            f.savefig('//home/zaringleb/curve_pic_storage/training.png')
        else:
            f.savefig('//home/zaringleb/curve_pic_storage/prod.png')
    finally:
        # Release the figure; the caller plots many dates in one process.
        plt.close(f)

def prepare_df(df, min_show=100000):
    """Filter out low-traffic curves and turn event counts into cumulative sums.

    Args:
        df: DataFrame with the ``PROPS`` columns, ``training``, ``score`` and
            one column per entry of ``EVENTTYPES``.
        min_show: A (client, sub_client) pair is kept only if its total
            ``show`` count is strictly greater than this value.

    Returns:
        The result of applying ``group_func`` to every
        (client, sub_client, training) group of the filtered frame.  The
        frame also carries a ``show_sum`` column with the per-pair total.
    """
    totals = df.groupby(PROPS, as_index=False).agg({'show': sum})
    totals = totals.rename(columns={'show': "show_sum"})
    totals = totals[totals["show_sum"] > min_show]
    # The inner join drops every row whose pair did not pass the threshold.
    df = pd.merge(df, totals, on=PROPS, how='inner')
    df = df.sort_values("score", ascending=False)
    return df.groupby(PROPS + ['training'], as_index=False).apply(group_func)

def group_func(df):
    """Return a copy of *df* ordered by descending score with cumulative event counts.

    Each column named in ``EVENTTYPES`` is replaced by its running sum taken
    in descending ``score`` order.  The input frame is left untouched: this
    function is used as a ``groupby(...).apply`` callback, where mutating the
    group that pandas passes in is discouraged.
    """
    ordered = df.sort_values("score", ascending=False)
    for one in EVENTTYPES:
        ordered[one] = ordered[one].cumsum()
    return ordered

def get_rate(date, clients_sub_clients):
    """Fetch show/threshold signals from Golovan and derive a rate per pair.

    Args:
        date: ``datetime.date`` (or ``datetime``) of the day being processed.
        clients_sub_clients: Iterable of ``(client, sub_client)`` pairs.

    Returns:
        Dict keyed by ``(client, sub_client)``.  Each value is a dict holding
        the signals returned for that pair under ``'show'`` and/or
        ``'threshold'`` plus ``'rate' = show / (show + threshold)`` (0 when
        the denominator is 0; a missing or ``None`` signal counts as 0).
        An empty input, or a request that yields no points, gives ``{}``.
    """
    host = "ASEARCH"
    period = 86400

    # Map every requested signal back to the pair and quantity it describes.
    signal_dict = {}
    signals = []
    for client, sub_client in clients_sub_clients:
        tags = "itype=atom;prj=web-mobreport-atom-search;tier={}".format(client.replace("_", "-"))
        # NOTE(review): the original applied this empty-sub_client special
        # case to the show signal only, so the threshold signal became
        # "unistat--LastFilteringStageThreshold_dmmm"; both now share the
        # same prefix -- confirm against the real signal names.
        prefix = "unistat-{}-".format(sub_client) if sub_client else "unistat-"
        for kind, suffix in (('show', "Shows_dmmm"),
                             ('threshold', "LastFilteringStageThreshold_dmmm")):
            signal = "{}:{}{}".format(tags, prefix, suffix)
            signal_dict[signal] = (client, sub_client, kind)
            signals.append(signal)
    if not signals:
        # Nothing to ask for; avoid an empty request.
        return {}

    # Local midnight of the following day, as a Unix timestamp.
    # time.mktime is used instead of the platform-specific strftime("%s").
    et = int(time.mktime(date.timetuple())) + 24*60*60  # next day
    st = et  # start == end: a single point is requested

    # Only the last point yielded by the request is kept.
    values = {}
    for _timestamp, values in GolovanRequest(host, period, st, et, signals, explicit_fail=True):
        pass

    rate_dict = {}
    for one in values:
        client, sub_client, kind = signal_dict[one]
        rate_dict.setdefault((client, sub_client), {})[kind] = values[one]

    for entry in rate_dict.values():
        show = entry.get('show') or 0
        threshold = entry.get('threshold') or 0
        total = show + threshold
        entry['rate'] = show * 1.0 / total if total else 0
    return rate_dict

def process_table(date, logger):
    """Aggregate the event table for *date* on the cluster.

    Projects the ``PROPS``, ``EVENTTYPES`` and ``score`` fields plus a
    ``training`` flag (1 when ``distr_obj`` contains ``'.training'``, else
    0), sums every event type per (client, sub_client, training, score) and
    stores the result in the ``grouped_<date>`` table.

    Returns:
        The table object returned by ``put``; ``None`` (after logging the
        error) when the source table is missing.
    """
    owner = 'zaringleb'
    hahn = clusters.Hahn(pool = 'search-research_{}'.format(owner))
    source_path = "//home/atom/zaringleb/eventtype_tables/table_{}".format(date)
    target_path = "//home/atom/zaringleb/Curve/grouped_{}".format(date)
    projected_fields = PROPS + EVENTTYPES + ['score']
    group_keys = PROPS + ['training', 'score']
    try:
        job = hahn.job()
        projected = job.table(source_path).project(
            *projected_fields,
            training=ne.custom(lambda distr_obj: 1*('.training' in distr_obj), 'distr_obj')
        )
        summed = projected.groupby(*group_keys).aggregate(
            **{event: na.sum(event) for event in EVENTTYPES}
        )
        table = summed.put(target_path)
        job.run()
        logger.debug('process_table done')
        return table
    except nile.nodes.table.MissingSourceTablesError as e:
        logger.info(e)
    return None

def get_past_dates():
    """Return the dates already processed, as the strings stored in the file.

    Reads ``done_dates.txt`` under ``PATH``, one date per line, with
    surrounding whitespace stripped and blank lines skipped.  The file is
    only created by ``put_past_date``, so on a first run it does not exist
    yet; an empty list is returned in that case.
    """
    done_path = PATH + 'done_dates.txt'
    if not os.path.isfile(done_path):
        return []
    with open(done_path) as f:
        stripped = (line.strip() for line in f)
        return [date for date in stripped if date]

def put_past_date(date):
    """Append *date* (a string) as a new line to ``done_dates.txt`` under ``PATH``."""
    done_path = PATH + 'done_dates.txt'
    with open(done_path, 'a') as done_file:
        record = date + "\n"
        done_file.write(record)

def process_date(date, logger):
    """Build and plot the prod and training curves for one date.

    Aggregates the day's table, converts it into cumulative curves, saves
    the prod picture, fetches the rates and saves the training picture, and
    finally records the date as done.  When ``process_table`` returns no
    table nothing is plotted and the date is not recorded.

    Args:
        date: Day to process; formatted into table and picture names.
        logger: Logger used for progress messages.
    """
    logger.info(date)
    grouped_table = process_table(date, logger)
    logger.info('grouped_table: {}'.format(grouped_table))
    if grouped_table is None:
        # process_table returns None when the source table is missing.
        return

    df = grouped_table.read().as_dataframe()
    logger.debug('len(df): {}'.format(len(df)))
    df = prepare_df(df)

    logger.debug('plotting prod')
    plot_pic_b(df[df['training'] == 0], 'prod_{}.png'.format(date), 'prod_{}'.format(date), logger)

    logger.debug('get rates')
    # The aggregation is only used to obtain the distinct (client, sub_client) pairs.
    df_sum = df.groupby(PROPS, as_index = False).agg({'show':sum})
    rate_dict = get_rate(date, df_sum[['client', 'sub_client']].values)

    logger.debug('plotting train')
    # The original picture name had no '{}' placeholder, so .format(date) was
    # a no-op; it now carries the date like the prod picture does.
    plot_pic_b(df[df['training'] == 1], 'training_{}.png'.format(date), 'training_{}'.format(date), logger, hist=True, rate_dict=rate_dict)
    put_past_date(str(date))

def main():
    """Command-line entry point: process every day that is not yet done.

    Flags:
        --debug          log at DEBUG instead of INFO level.
        --log_to_stdout  additionally attach a ``logging.StreamHandler()``
                         (which writes to stderr unless given a stream).

    Days run from 2016-11-04 up to, but not including, today; those already
    listed by ``get_past_dates`` are skipped.
    """
    # Command line
    cli = argparse.ArgumentParser()
    cli.add_argument('--debug', action='store_true')
    cli.add_argument('--log_to_stdout', action='store_true')
    options = cli.parse_args()

    # Logging: the logger passes everything, the handlers apply the chosen level.
    logger = logging.getLogger('curve')
    logger.setLevel(logging.DEBUG)
    if options.debug:
        handler_level = logging.DEBUG
    else:
        handler_level = logging.INFO
    formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')

    handlers = [logging.FileHandler(PATH + 'my_log.log')]
    if options.log_to_stdout:
        handlers.append(logging.StreamHandler())
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(handler_level)
        logger.addHandler(handler)

    # Curves
    logger.info('Start')
    first_day = datetime.datetime.strptime('2016-11-04', "%Y-%m-%d")
    elapsed_days = (datetime.datetime.now() - first_day).days
    candidates = [
        (first_day + datetime.timedelta(days=offset)).date()
        for offset in range(elapsed_days)
    ]
    already_done = get_past_dates()
    for day in candidates:
        if str(day) in already_done:
            continue
        print(day)
        process_date(day, logger)
    logger.info('Done')

# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
