import time
from datetime import datetime, timedelta

from absl import app, flags, logging
from absl.flags import FLAGS
import psycopg2
import pandas as pd
import pytz

from mgst_data.config import DB_HOST, DB_USER, DB_PASSWORD
from mgst_data.utils import upsert_data, query
from mgst_data.experiment_analysis.analysis import Cohort, lifecycle_cohorts
from mgst_data.experiment_analysis import analysis


# --- Command-line flags -------------------------------------------------------
# Optional filters that narrow which experiments main() analyzes.
flags.DEFINE_string(name='experiment_id', default=None, help='Experiment ID')
flags.DEFINE_string(name='experiment_name', default=None, help='Experiment Name')
flags.DEFINE_integer(name='experiment_version', default=None,
                     help='Experiment Version')
# When set, analyze_one() runs only this cohort instead of the full set.
flags.DEFINE_enum_class(name='cohort', default=None, enum_class=Cohort,
                        help='Cohort to analyze')
# When true, main() drops experiments whose end date is more than 2 days old.
flags.DEFINE_boolean(name='skip_ended', default=True,
                     help='Skip ended experiment')
# When true, main() runs the post-experiment analysis instead.
flags.DEFINE_boolean(name='post_exp', default=False, help='Post Exp Analysis')
# When true, analyze_cohort() prints the rendered SQL instead of executing it.
flags.DEFINE_boolean(name='print', default=False,
                     help='Print Query Instead of Running')


def dump_result(con, df, exp, cohort):
    """Reshape one cohort's metrics to long format and upsert them.

    Args:
        con: database connection handed through to ``upsert_data``.
        df: wide frame with ``iteration`` and ``experiment_group`` columns;
            every other column is treated as a metric.
        exp: experiment row exposing ``experiment_id`` and
            ``experiment_version`` attributes.
        cohort: cohort enum member; its ``.name`` is stored.

    Every written row is stamped with today's date in US/Pacific
    (``YYYY-MM-DD``).
    """
    logging.info("%s", df)

    # One row per (iteration, experiment_group, metric).
    long_df = df.melt(
        id_vars=['iteration', 'experiment_group'],
        var_name='metric_name',
        value_name='metric_value',
    )

    pacific = pytz.timezone("US/Pacific")
    run_date = datetime.now(pytz.utc).astimezone(pacific).strftime("%Y-%m-%d")

    key_prefix = (exp.experiment_id, exp.experiment_version, cohort.name, run_date)
    rows = [
        [*key_prefix, rec.iteration, rec.experiment_group,
         rec.metric_name, rec.metric_value]
        for rec in long_df.itertuples()
    ]

    result_columns = ['experiment_id', 'experiment_version', 'cohort', 'date',
                      'iteration', 'experiment_group', 'metric_name',
                      'metric_value']
    # NOTE(review): only these three columns are passed as the key. Depending
    # on upsert_data's semantics (not visible here), each run presumably
    # replaces all earlier rows for this experiment/version/cohort. Confirm.
    key_columns = ['experiment_id', 'experiment_version', 'cohort']
    upsert_data(con, 'mgst.experiment_results', result_columns, key_columns, rows)


def _extra_query_kwargs(exp):
    """Return experiment-specific keyword arguments for the query builders.

    The first rule matching ``exp.latest_name`` wins. An empty dict is
    returned when no rule applies.
    """
    name = exp.latest_name
    if 'mobile_web' in name:
        return analysis.extras_for_mweb()
    if 'chat' in name:
        return analysis.extras_for_chat()
    if 'clipfinity' in name:
        return analysis.load_extras('extra_clipfinity')
    if name == 'ios_browse_updates':
        return analysis.load_extras('extra_no_autoplay')
    if name == 'logged_out_notifications':
        return analysis.load_extras('extra_logged_out_adj')
    return {}


def _build_cohort_query(cohort, params, builder_kwargs):
    """Return the SQL template for ``cohort``.

    Lifecycle cohorts also add their extra bind parameters to ``params``
    in place.

    Raises:
        ValueError: if ``cohort`` is not a supported cohort.
    """
    if cohort == Cohort.everyone:
        return analysis.build_query(**builder_kwargs)
    if cohort in lifecycle_cohorts:
        sql, extra_params = analysis.build_lifecycle_query(
            lifecycle_cohorts[cohort], **builder_kwargs)
        params.update(extra_params)
        return sql
    if cohort == Cohort.mgst_viewers:
        return analysis.build_mgst_viewers_query(**builder_kwargs)
    if cohort == Cohort.first_app_open:
        return analysis.build_first_app_open_query(**builder_kwargs)
    if cohort == Cohort.discover:
        return analysis.build_discover_user_query(**builder_kwargs)
    raise ValueError(f"unknown cohort {cohort}")


def analyze_cohort(con, exp, cohort, games, post):
    """Build, run and store the analysis query for one cohort of an experiment.

    Args:
        con: psycopg2 connection.
        exp: experiment row (``experiment_id``, ``experiment_version``,
            ``latest_name``).
        cohort: ``Cohort`` member to analyze.
        games: tuple of game names bound as ``target_games``.
        post: if true, run the post-experiment query, which is not filtered
            by ``experiment_version``.

    Returns:
        True if results were stored, or if ``--print`` only printed the SQL.
        False if the query returned no rows.

    Raises:
        ValueError: for an unsupported cohort.
    """
    print('Cohort Analysis', cohort)

    if post:
        params = {
            'experiment_id': exp.experiment_id,
            'target_games': games,
        }
    else:
        params = {
            'experiment_id': exp.experiment_id,
            'experiment_version': exp.experiment_version,
            'target_games': games,
        }

    # Copy so a dict returned by the analysis helpers is never mutated.
    builder_kwargs = dict(_extra_query_kwargs(exp))
    if post:
        # Set after merging extras, so experiment-specific extras cannot
        # silently drop the post-experiment base query. The post params above
        # carry no experiment_version (presumably needed by the default query).
        builder_kwargs['base_query'] = 'sql/post_exp_analysis.sql'

    template = _build_cohort_query(cohort, params, builder_kwargs)

    with con.cursor() as cur:
        rendered = cur.mogrify(template, params)
    # mogrify returns bytes. Decode with the connection's encoding so the SQL
    # prints verbatim and pd.read_sql gets a str (recent pandas versions
    # reject non-str queries on plain DBAPI connections).
    rendered = rendered.decode(psycopg2.extensions.encodings[con.encoding])

    if FLAGS.print:
        print(rendered)
        return True

    df = pd.read_sql(rendered, con)
    if df.empty:
        print('No Result')
        return False

    dump_result(con, df, exp, cohort)
    return True


def analyze_one(con, exp, games, post=False):
    """Run the cohort analyses for a single experiment.

    With ``--cohort`` only that cohort is analyzed. Otherwise the
    ``everyone`` cohort runs first and acts as a gate: if it yields no
    result, nothing else runs. After that come every lifecycle cohort,
    ``mgst_viewers`` and ``first_app_open``. ``discover`` is added when the
    experiment's latest name contains 'discover'.
    """
    logging.info("Analyzing %s", exp)

    only_cohort = FLAGS.cohort
    if only_cohort:
        analyze_cohort(con, exp, only_cohort, games, post)
        return

    has_result = analyze_cohort(con, exp, Cohort.everyone, games, post)
    if not has_result:
        return  # no base-cohort result, so the sub-cohorts are skipped too

    remaining = [*lifecycle_cohorts, Cohort.mgst_viewers, Cohort.first_app_open]
    for sub_cohort in remaining:
        analyze_cohort(con, exp, sub_cohort, games, post)

    if 'discover' in exp.latest_name:
        analyze_cohort(con, exp, Cohort.discover, games, post)


def get_experiment_metadata(con):
    """Load metadata for experiments that appear in mgst.experiment_device.

    Returns:
        DataFrame of distinct rows with columns ``experiment_id``,
        ``experiment_version``, ``start_date`` / ``end_date`` (day-truncated
        and parsed as datetimes), ``latest_name`` and ``latest_version``.
        ``latest_name`` is the experiment_name at the highest
        experiment_version for each experiment_id (arbitrary on ties).
        ``latest_version`` is that maximum version. Rows are ordered by
        experiment_id, then experiment_version.
    """
    metadata_sql = """
        SELECT distinct experiment_id,
                        experiment_version,
                        date_trunc('day', start_time) as start_date,
                        date_trunc('day', end_time) as end_date,
                        last_value(experiment_name) over (partition BY experiment_id
                                                        ORDER BY experiment_version 
                                                        ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS latest_name,
                        max(experiment_version) over (partition BY experiment_id) AS latest_version
        FROM cubes.expo_experiment_metadata
        WHERE experiment_id in (SELECT distinct experiment_id FROM mgst.experiment_device)
        ORDER BY 1,
                 2
    """
    date_columns = ['start_date', 'end_date']
    return pd.read_sql(metadata_sql, con, parse_dates=date_columns)


def _fetch_target_games(con):
    """Return the lower-cased names of the games used as ``target_games``.

    Selects titles matching the Free Fire, PUBG Mobile or Call of Duty Mobile
    name patterns that had live viewership from Brazil or Mexico in
    cubes.hours_watched_daily during the past week.
    """
    rows = query(con, """
        SELECT LOWER(game) AS game
        FROM cubes.hours_watched_daily AS hours_watched
        WHERE day >= CURRENT_DATE - interval '1 weeks'
            AND ((LOWER(game) LIKE '%free%fire%'
                OR LOWER(game) LIKE '%pubg%mobile%'
                OR LOWER(game) LIKE '%call of duty%mobile%'))
        AND viewer_country_code in ('br',
                                    'mx')
        AND video_product_type = 'live'
        GROUP BY 1 ORDER BY 1
        """, fetch=True)
    return tuple(row[0] for row in rows)


def _filter_experiments(experiments):
    """Narrow ``experiments`` by the --experiment_id/_name/_version flags.

    Unset flags apply no filter.
    """
    if FLAGS.experiment_id:
        experiments = experiments[experiments.experiment_id ==
                                  FLAGS.experiment_id]
    if FLAGS.experiment_name:
        experiments = experiments[experiments.latest_name ==
                                  FLAGS.experiment_name]
    # `is not None` so that an explicit --experiment_version=0 still filters.
    if FLAGS.experiment_version is not None:
        experiments = experiments[experiments.experiment_version ==
                                  FLAGS.experiment_version]
    return experiments


def main(_argv):
    """Entry point: analyze the selected experiments and store the results.

    With --post_exp, every selected experiment is analyzed once with the
    post-experiment query and labelled version 'final'. Otherwise each
    selected experiment version is analyzed. With --skip_ended (the
    default), that only includes versions that are ongoing or ended within
    the last 2 days.
    """
    start_time = time.time()

    con = psycopg2.connect(dbname='product', host=DB_HOST,
                           port='5439', user=DB_USER, password=DB_PASSWORD)
    try:
        games = _fetch_target_games(con)
        logging.info("%s", games)

        experiments = _filter_experiments(get_experiment_metadata(con))

        if FLAGS.post_exp:
            # assign() returns a fresh frame, so there is no write into a
            # slice of `experiments` (avoids SettingWithCopyWarning).
            exps = (experiments[['experiment_id', 'latest_name']]
                    .drop_duplicates()
                    .assign(experiment_version='final'))
            print(exps)
            for exp in exps.itertuples():
                analyze_one(con, exp, games, post=True)
        else:
            if FLAGS.skip_ended:
                # Keep experiments that are ongoing or ended in the last 2 days.
                cutoff = datetime.utcnow() - timedelta(days=2)
                experiments = experiments[experiments.end_date.isnull() |
                                          (experiments.end_date >= cutoff)]

            print(experiments)

            for exp in experiments.itertuples():
                analyze_one(con, exp, games)
    finally:
        con.close()

    logging.info("Done %ds", time.time() - start_time)


if __name__ == '__main__':
    # absl's app.run handles flag parsing, then invokes main().
    app.run(main=main)
