#!/usr/bin/env python
# coding: utf-8

# # Расчет MDE для метрик cofe на основе мониторинговых выборок.
#
# Дашборд с результатами https://dash.yandex-team.ru/tsclrvbd5hhyp
#
# Поддерживается расчет count- и ratio-метрик (для ratio-метрик требуется указание feature2 — числителя дроби)
#
# Указание номера control-выборки обязательно. Treatment — опционально (при его отсутствии дисперсия разности считается удвоенной дисперсией контроля, т.е. $\sigma = \sigma_0 \sqrt{2}$).
#
# Для добавления новой метрики необходимо добавить её фичи и выборку, по которой считать, в список metrics_to_run

# ## Мониторинги используемые при расчетах
#
# ### [EXPERIMENTS-31314] Мониторинг Мобильные карты UI Android
# [144208] Android Контроль 1  [144209] Android Контроль 2
#
# ### [EXPERIMENTS-32763] Мониторинг Мобильные карты UI iOS
# [144238] iOS Контроль 1, [144239] iOS Контроль 2
#
# ### [EXPERIMENTS-24058] Мониторинг Карты.Touch
#   [97489] Контроль 1 ( testing — скриншот )
#   [97491] Контроль 2 ( testing — скриншот )
#
# ### [EXPERIMENTS-14763] Мониторинг Карты Геопоиск
#   [48144] Контроль 1
#   [48145] Контроль 2
#
# ### Большой эксперимент в RTX с Яндекс Районом https://st.yandex-team.ru/EXPERIMENTS-29760
# [135685] Большой эксперимент в RTX с Яндекс Районом. Контроль <br />
# [135686] Большой эксперимент в RTX с Яндекс Районом <br />
#
# ### Мониторинг Колдунщик карт. Геопоиск без пересаливания  https://st.yandex-team.ru/EXPERIMENTS-3704
# [182560] Контроль 1 <br />
# [182561] Контроль 2 <br />
#!/usr/bin/env python
# coding: utf-8

import sys
import os
import json
import pandas as pd
import numpy as np
import requests
from yql.api.v1.client import YqlClient
from nile.api.v1 import clusters, Record

def main(in1, in2, in3, mr_tables, token1=None, token2=None, param1=None, param2=None, html_file=None):
    """
    Refresh the stored cofe bucket table and recompute the MDE table for every metric.

    in1 is used as the list of metric descriptions (metrics_to_run); each item
    is read by the keys control, treatment, feature1, feature2, slicehash,
    metric_name, service and testid_pct.
    token1 is used as the YT token (for CHYT requests), token2 as the YQL token.
    in2, in3, mr_tables, param1, param2 and html_file are not referenced in the body.
    The function does not return anything; its results are written to YT tables.
    """
    metrics_to_run = in1
    yt_token = token1
    yql_token = token2

    # One YQL client is shared by all queries issued from this function.
    YQL_CLIENT = YqlClient(token=yql_token)
    def run_yql_df(query_text, title='collect cofe buckets', client=YQL_CLIENT):
        """
        Run query_text through the given YQL client and return the
        full_dataframe of the finished request.
        The operation title is the given title with a ' YQL' suffix.
        """
        request = client.query(
            query_text,
            syntax_version=1,
            title='%s YQL' % title,
        )
        finished = request.run()
        return finished.full_dataframe

    def execute_query(query, cluster, alias, token, timeout=600):
        """
        Run a query through the CHYT HTTP interface of the given YT cluster.

        query is sent as the POST body, alias is passed as the `database`
        URL parameter and token as the `password` URL parameter.
        Returns the response body split into lines (a list of strings).
        Raises requests.HTTPError when the server answers with an error status.
        """
        proxy = "http://{}.yt.yandex.net".format(cluster)
        # NOTE(review): the token travels in the URL over plain http, so it can end up
        # in proxy/access logs; consider an Authorization header if the endpoint allows it.
        url = "{proxy}/query?database={alias}&password={token}".format(proxy=proxy, alias=alias, token=token)
        # The context manager closes the session (and its pooled connections) even on errors.
        with requests.Session() as session:
            resp = session.post(url, data=query, timeout=timeout)
            resp.raise_for_status()
            # Decode explicitly as UTF-8 instead of relying on the interpreter default.
            return resp.content.decode('utf-8').strip().split('\n')

    def results_to_df(result):
        """
        Turn tab-separated text lines into a DataFrame.

        The first line supplies the column names, the remaining lines are the
        data rows; all values stay strings and the row index is not reset.
        """
        table = pd.DataFrame([line.split('\t') for line in result])
        header = table.loc[0].tolist()
        table.columns = header
        return table.iloc[1:]

    def get_from_ch(feature, testid, slicehash, yt_token = yt_token):
        """
        Read the stored bucket rows of one (testid, feature, slice) combination
        from //home/geoadv/ayunts/store_cofe_buckets through CHYT.

        Returns a DataFrame built from the tab-separated answer; besides the
        stored columns it has `dt`, the date part of `ts`.
        """
        query_template = """
                SELECT
                    "testid",
                    "bucket_num",
                    "bucket_value",
                    "value",
                    "slice",
                    "feature",
                    "ts"
                ,toDate(ts) as dt
                FROM "//home/geoadv/ayunts/store_cofe_buckets"
                WHERE testid = '{}'
                and feature = '{}'
                and slice = '{}'
                format TabSeparatedWithNames
                """
        rows = execute_query(
            query=query_template.format(testid, feature, slicehash),
            cluster='hahn',  # cluster name
            alias="*maps_datalens_chyt",
            token=yt_token,
        )
        return results_to_df(rows)


    # ### Функция для расчета MDE


    def get_feature_df(control, feature1, metric_name,
                    feature2 = None, treatment = None,
                    slicehash =  'd41d8cd98f00b204e9800998ecf8427e',
                    testid_pct = 0.02, service = 'desktop', alpha = 0.01, power = 0.8,
                    sample_sizes_list = (0.02, 0.04, 0.05, 0.1, 0.2, 0.3, 0.4)):

        """
        Build the MDE table of one metric for a growing number of observation days.

        For count metrics only feature1 has to be specified.
        For ratio metrics the per-bucket value is feature2 / feature1, i.e.
        feature2 is the numerator and feature1 is the denominator.
        slicehash is used for filtering service, the general one is set by default.
        If treatment is None, the control std is multiplied by sqrt(2) instead of
        combining the control and treatment stds.
        testid_pct is the sample share the measurements belong to; rows for the
        other shares from sample_sizes_list are derived by rescaling the day count.
        alpha and power are kept in the signature but are NOT used: z is fixed to
        the quantiles for alpha = 0.01 and power = 0.8.
        Formula to get std - https://a.yandex-team.ru/arc/trunk/arcadia/quality/ab_testing/cofe/python/metrics/fetcher.py#L229

        Returns a DataFrame with the columns days, std, MDE_abs, MDE_diff
        (MDE_abs as a percentage of control_val), sample_size, metric_name,
        service and control_val. The frame is empty when no usable day is found.
        """

        has_treatment, is_count_metric = treatment is not None, feature2 is None
        result_columns = ['std', 'MDE_abs', 'MDE_diff', 'sample_size', 'metric_name','service', 'control_val']

        def get_ready_pvt(feature, test_group, slicehash):
            """
            Return a bucket_num x day table of bucket values accumulated from
            the newest usable day backwards (columns are ordered newest first).
            """
            dataset = get_from_ch(feature, test_group, slicehash)
            try:
                dataset.bucket_value = dataset.bucket_value.astype(int)
            except (ValueError, TypeError, OverflowError):
                # Values are not all integer literals - fall back to floats.
                dataset.bucket_value = dataset.bucket_value.astype(float)
            dataset.bucket_value = dataset.bucket_value.fillna(0)
            daily_totals = dataset.groupby('dt').bucket_value.sum().reset_index()
            # Days whose total over all buckets is not above 5 are dropped.
            ok_days = set(daily_totals.query("bucket_value>5")['dt'].tolist())

            # '-01-0' drops the dates January 1-9 of every year.
            days_to_check = [day for day in sorted(dataset['dt'].unique(), reverse=True)
                             if day in ok_days and '-01-0' not in day]

            dataset = dataset[dataset['dt'].isin(days_to_check)]
            pvt = dataset.pivot_table(index='bucket_num',
                                        columns='dt',
                                        values='bucket_value',
                                        aggfunc='sum').fillna(0)
            return pvt[days_to_check].cumsum(axis=1)

        def precision(bucket_values, coef):
            """Scale the bucket std: std * coef for count metrics, std / coef for ratio metrics."""
            spread = np.std(bucket_values)
            return spread * coef if is_count_metric else spread / coef

        control_df = get_ready_pvt(feature1, control, slicehash)
        pvt_trtm = None

        if is_count_metric:
            pvt_ctrl = control_df
            if has_treatment:
                pvt_trtm = get_ready_pvt(feature1, treatment, slicehash)
                common_columns = sorted(set(pvt_ctrl.columns) & set(pvt_trtm.columns), reverse=True)
                pvt_ctrl = pvt_ctrl[common_columns]
                pvt_trtm = pvt_trtm[common_columns]
            control_val = pvt_ctrl.sum()
        else:
            pvt_ctrl_1 = control_df
            pvt_ctrl_2 = get_ready_pvt(feature2, control, slicehash)
            common_days = set(pvt_ctrl_1.columns) & set(pvt_ctrl_2.columns)
            if has_treatment:
                pvt_trtm_1 = get_ready_pvt(feature1, treatment, slicehash)
                pvt_trtm_2 = get_ready_pvt(feature2, treatment, slicehash)
                common_days &= set(pvt_trtm_1.columns) & set(pvt_trtm_2.columns)
            common_columns = sorted(common_days, reverse=True)

            pvt_ctrl = pvt_ctrl_2[common_columns]/pvt_ctrl_1[common_columns]
            if has_treatment:
                pvt_trtm = pvt_trtm_2[common_columns]/pvt_trtm_1[common_columns]
            control_val = pvt_ctrl_2[common_columns].sum()/pvt_ctrl_1[common_columns].sum()

        values = {}

        # The calculation follows the cofe sources: https://a.yandex-team.ru/arc/trunk/arcadia/quality/ab_testing/cofe/python/metrics/fetcher.py#L229
        #alpha_z, power_z = norm.ppf(1-alpha/2), norm.ppf(power)
        #z = alpha_z + power_z
        z = 0.8416212335729143+2.5758293035489004

        # Only days present in every pivot that is actually used; taking them from
        # the feature1 pivot alone breaks ratio metrics when feature2 lacks a day.
        days_to_check = sorted(pvt_ctrl.columns, reverse=True)

        max_days = len(days_to_check)
        # Zero-based position of the last day that completes a whole number of weeks.
        baseline_idx = (max_days//7)*7-1
        has_baseline = baseline_idx >= 6
        sev_days_std, sev_days_val = None, None

        for i, check_date in enumerate(days_to_check):
            ds1 = pvt_ctrl[check_date]
            coef = len(ds1) ** .5
            prec1 = precision(ds1, coef)
            if has_treatment:
                prec2 = precision(pvt_trtm[check_date], coef)
                stdd = ((prec2 * prec2) + (prec1 * prec1)) ** .5
            else:
                stdd = prec1 * np.sqrt(2)

            day_val = control_val[check_date]
            if has_baseline and i == baseline_idx:
                sev_days_std, sev_days_val = stdd, day_val

            # Position i holds i + 1 accumulated days.
            values[i+1] = [stdd, stdd*z, 100*stdd*z/day_val, testid_pct, metric_name, service, day_val]

        if has_baseline:
            # The baseline was measured over baseline_idx + 1 days, so that is the
            # divisor for extrapolation (float keeps true division on Python 2 too).
            # NOTE(review): std ~ sqrt(days) and value ~ days is the count-metric model;
            # for ratio metrics only the relative MDE_diff follows the same law - confirm
            # whether std/MDE_abs/control_val of extrapolated ratio rows are consumed.
            baseline_n = float(baseline_idx + 1)
            for new_days in range(max_days+1, 360):
                growth = new_days/baseline_n
                new_std = sev_days_std * growth ** .5
                new_val = sev_days_val * growth
                values[new_days] = [new_std, new_std*z, 100*new_std*z/new_val, testid_pct, metric_name, service, new_val]

        if not values:
            # No usable day at all: return an empty table with the usual columns.
            return pd.DataFrame(columns=['days'] + result_columns)

        df = pd.DataFrame.from_dict(values, orient='index')
        df.columns = result_columns

        df.index.names = ['days']
        df.reset_index(inplace = True)
        df.query("days>0", inplace = True)
        raw_df = df.copy()

        for ssl in sample_sizes_list:
            if ssl in df['sample_size'].tolist():
                continue
            multi = np.round(ssl/testid_pct)
            if multi < 1:
                # The rounded multiplier is 0 for shares far below testid_pct;
                # dividing by it would only produce rows with infinite days.
                continue
            new_df = raw_df.copy()
            new_df['days'] = (new_df['days']/multi).round(2)
            # Keep only the rows that map to a whole number of days.
            new_df = new_df[new_df['days'] == new_df['days'].round()].copy()
            new_df['sample_size'] = ssl
            df = pd.concat([df, new_df], axis=0)
        return df.drop_duplicates()


    # Test ids to collect: every control plus every treatment that is actually set.
    # A missing treatment is None and must not get into the filter as the string 'None'.
    testid_values = [str(metric['control']) for metric in metrics_to_run]
    testid_values += [str(metric['treatment']) for metric in metrics_to_run if metric['treatment'] is not None]
    # Sorting makes the generated query text the same from run to run.
    testids = tuple(sorted(set(testid_values)))
    features1 = [str(metric['feature1']) for metric in metrics_to_run]
    features2 = [str(metric['feature2']) for metric in metrics_to_run if metric['feature2'] is not None]
    features = tuple(sorted(set(features1 + features2)))
    slicehashes = tuple(sorted(set(str(metric['slicehash']) for metric in metrics_to_run)))

    # The query re-reads the last days of cofe features for the requested
    # testids/features/slices and rewrites the store table: fresh rows plus the
    # stored rows older than the oldest fresh ts.
    # The tuples are substituted by their Python repr into the three `in {}` filters.
    # NOTE(review): bucket_num is ROW_NUMBER() over a window that is partitioned by
    # `value` and also ordered by `value`, so the order inside a partition is not fixed
    # by the query itself; confirm bucket numbers are stable between days, because
    # get_feature_df pivots and accumulates bucket values by bucket_num.
    yql_q = """
    USE hahn;
    PRAGMA SimpleColumns;
    PRAGMA yson.DisableStrict;
    -- PRAGMA yt.Pool = 'geoadv';
    PRAGMA yt.PoolTrees = "physical";
    PRAGMA yt.TentativePoolTrees = "cloud";
    PRAGMA yt.InferSchema = '1';

    $ystrd = DateTime::Format("%Y-%m-%d")(CurrentUtcDate()-Datetime::IntervalFromDays(1));
    $week_ago = DateTime::Format("%Y-%m-%d")(CurrentUtcDate()-Datetime::IntervalFromDays(8));

    DEFINE SUBQUERY $get_features($folder,$datestart,$ystrd) as
    $df1 = SELECT `slice`, `testid`, `feature`, `ts`,  `value`,
            ListMap(Yson::ConvertToList(buckets), ($x) -> {{ RETURN Yson::ConvertToDouble($x); }}) as buckets,
    FROM RANGE('home/abt/cofe/geo/features/'||$folder||'/daily',$datestart,$ystrd, `main/features/0`)
    WHERE testid in {}
    AND feature in {}
    and slice in {}
    ;
    $df2 = select a.* without buckets from $df1 as a flatten list by buckets as bucket_value;
    $df3 = select a.*, ROW_NUMBER() OVER w AS bucket_num from $df2 as a
    window w as (partition by feature, slice, testid, ts, value order by value);
    select * from $df3 where bucket_value is not null; end define;

    $new =
    select * from $get_features('processed_joins_for_geoadv_web',$week_ago,$ystrd)
    union all
    select * from $get_features('processed_joins_for_geoadv_app',$week_ago,$ystrd)
    union all
    select * from $get_features('maps_heavy',$week_ago,$ystrd)
    union all
    select * from $get_features('main',$week_ago,$ystrd)
    union all
    select * from $get_features('geosuccube',$week_ago, $ystrd)
    union all
    select * from $get_features('maps_metrika_mobile',$week_ago, $ystrd)
    ;

    $min_ts = select min(ts) from $new;

    $store_path = '//home/geoadv/ayunts/store_cofe_buckets';

    INSERT INTO $store_path WITH TRUNCATE
    select * from $new
    union all
    select * from $store_path
    where ts<$min_ts;

    COMMIT;
    select 1;
    """.format(testids,features,slicehashes)

    # Only the side effect (the refreshed store table) matters; the result is not used.
    run_yql_df(yql_q, title='cooking cofe buckets', client=YQL_CLIENT)

    # Example of a single call:
    # get_feature_df(
    #     control =  '48144', treatment = '48145',
    #     feature1 = 'geoadv_business_connection.has_goal_deep_use.count$',
    #     feature2 = None,
    #     metric_name = 'GDU геопродукта',
    #     service='desktop_maps',
    #     slicehash =  'd41d8cd98f00b204e9800998ecf8427e',
    #     testid_pct = 0.02 )

    # ### Compute MDE for every metric

    dfs = [get_feature_df(
        control = metric['control'],
        treatment = metric['treatment'],
        feature1 = metric['feature1'],
        feature2 = metric['feature2'],
        metric_name = metric['metric_name'],
        service = metric['service'],
        slicehash =  metric['slicehash'],
        testid_pct = metric['testid_pct']
    ) for metric in metrics_to_run]

    df = pd.concat(dfs, axis=0)

    # The combined table of all metrics is written to YT through nile.
    cluster = clusters.Hahn()
    cluster.write('home/geoadv/ayunts/store_fearture_sizes', df)
