import argparse
from datetime import datetime
import logging
logging.basicConfig(level=logging.INFO)

import yt.yson
import yt.wrapper as yt

from lib_monitoring import write_monitoring_to_graphite
from ads.libs.yql import run_yql_query
from bm_mr.resources import home


# Default value of the --yt-cluster command-line option.
HAHN_DEFAULT = 'hahn'
# Directory that holds the per-day shows tables built by
# process_logs_and_get_newest_shows (created on demand).
SD_SHOWS_DIR = '//home/broadmatching/processed_logs/shows_by_sd_st_days'
# Directory whose children are the source log tables; the code below expects
# 24 tables per day, i.e. one table per hour.
BS_CHEVENT_LOG_BASE_PATH = '//logs/bs-chevent-log/1h'
# Example of a full source table path. Only its length is used, to work out
# how many leading characters of a table path identify the day.
SAMPLE_PATH = '//logs/bs-chevent-log/1h/2019-01-23T00:00:00'
# Number of most recent complete days fed into the sd_shows_lag query.
SHOWS_WINDOW = 7
# Maximum number of nodes kept in SD_SHOWS_DIR; older ones (by name) are removed.
KEEP_DAY_SHOWS_TABLES = 28


def process_logs_and_get_newest_shows(days, args):
    """Build missing per-day shows tables and return the newest ones.

    Lists the tables under BS_CHEVENT_LOG_BASE_PATH, keeps the last
    ``24 * (days + 1)`` of them (by name) and groups them by day. A day is
    considered complete when exactly 24 tables belong to it. For each of the
    last ``days`` complete days a table is created under SD_SHOWS_DIR via a
    YQL query, unless a node with that name already exists. Afterwards the
    oldest nodes (by name) in SD_SHOWS_DIR are removed so that at most
    KEEP_DAY_SHOWS_TABLES remain.

    Args:
        days: number of most recent complete days to return; must satisfy
            ``0 < days < KEEP_DAY_SHOWS_TABLES``.
        args: parsed command-line namespace. ``args.yt_cluster`` is passed to
            ``run_yql_query`` as ``db``; ``args.yt_pool`` (if present and
            non-empty) is used as the YQL pool, otherwise 'broadmatching'.

    Returns:
        List of per-day table paths under SD_SHOWS_DIR, oldest day first.
        It may contain fewer than ``days`` entries (or be empty) when not
        enough complete days are available.

    Raises:
        ValueError: if ``days`` is outside ``(0, KEEP_DAY_SHOWS_TABLES)``.
    """
    # Validate before any work is done: with days >= KEEP_DAY_SHOWS_TABLES the
    # retention step below could delete tables that were just built, and with
    # days <= 0 the ``[-days:]`` slice would silently select every day.
    if not 0 < days < KEEP_DAY_SHOWS_TABLES:
        raise ValueError(
            'days must be in range (0, {}), got {!r}'.format(KEEP_DAY_SHOWS_TABLES, days))

    pool = getattr(args, 'yt_pool', None) or 'broadmatching'

    # NOTE(review): yt.* calls use the module-level yt.wrapper client, whose
    # cluster is not set from args.yt_cluster anywhere in this file, while the
    # YQL queries run against args.yt_cluster - confirm both point at the same
    # cluster.
    hours_in_days = 24 * (days + 1)  # Add one for incomplete day
    recent_tables = sorted(
        '{}/{}'.format(BS_CHEVENT_LOG_BASE_PATH, x) for x in yt.list(BS_CHEVENT_LOG_BASE_PATH)
    )[-hours_in_days:]

    # Group the hourly tables by their day prefix in a single pass.
    day_prefix_len = len(SAMPLE_PATH) - len('T00:00:00')
    tables_by_day = {}
    for table in recent_tables:
        tables_by_day.setdefault(table[:day_prefix_len], []).append(table)
    complete_days = sorted(
        prefix for prefix, tables in tables_by_day.items() if len(tables) == 24)

    sd_shows_tables = []
    yt.create('map_node', SD_SHOWS_DIR, recursive=True, ignore_existing=True)
    for prefix in complete_days[-days:]:
        sd_shows_table_day = '{}/{}'.format(SD_SHOWS_DIR, prefix.replace('/', '_'))
        sd_shows_tables.append(sd_shows_table_day)
        if yt.exists(sd_shows_table_day):
            # Already built by an earlier run.
            continue
        query = '''
                pragma SimpleColumns;
                pragma yt.Pool = "{pool}";

                insert into `{sd_shows_table_day}` with truncate
                select EventTime,
                    SimDistance,
                    SelectType,
                    if ((ParentBannerID ?? 0) > 0, ParentBannerID, BannerID) as BannerID
                from (
                    select cast(eventtime as int64) as EventTime,
                    cast(simdistance as uint32) as SimDistance,
                    cast(selecttype as int32) as SelectType,
                    cast(parentbannerid as int64) as ParentBannerID,
                    cast(bannerid as int64) as BannerID
                    from concat({tables})
                    where fraudbits == '0' and placeid == '542'
                )
                order by BannerID, EventTime, SelectType, SimDistance;
            '''.format(
            pool=pool,
            sd_shows_table_day=sd_shows_table_day,
            tables=','.join('`{}`'.format(x) for x in tables_by_day[prefix]))
        run_yql_query(query, db=args.yt_cluster, title='monitoring: process_logs')

    # Retention: keep only the KEEP_DAY_SHOWS_TABLES newest nodes (by name).
    all_sd_shows_tables = sorted(
        '{}/{}'.format(SD_SHOWS_DIR, x) for x in yt.list(SD_SHOWS_DIR))
    excess = len(all_sd_shows_tables) - KEEP_DAY_SHOWS_TABLES
    if excess > 0:
        for table_to_remove in all_sd_shows_tables[:excess]:
            yt.remove(table_to_remove)

    return sd_shows_tables


def sd_shows_lag(args):
    """Compute the minimal show lag per timeslot and send it to graphite.

    Obtains the per-day shows tables for the last SHOWS_WINDOW days, joins
    them with the banner creation time table by BannerID and, for every
    (3-hour timeslot, SimDistance, SelectType) group, takes the minimum of
    ``EventTime - timestamp``. Each resulting row is written to graphite as
    ``bm.bmyt_sd_show_lag.by_st_sd.ST=<SelectType>.SD=<SimDistance>``; rows
    with SelectType equal to 5 are additionally written as
    ``bm.bmyt_sd_show_lag.by_sd.SD=<SimDistance>``.

    Args:
        args: parsed command-line namespace. ``args.yt_cluster`` is passed to
            ``run_yql_query`` as ``db``; ``args.yt_pool`` (if present and
            non-empty) is used as the YQL pool, otherwise 'broadmatching'.

    Raises:
        RuntimeError: if there are no per-day shows tables to query, or if
            fetching the query result fails.
    """
    sd_shows_tables = process_logs_and_get_newest_shows(SHOWS_WINDOW, args)
    if not sd_shows_tables:
        # An empty list would render as ``concat()`` in the query below.
        raise RuntimeError('No per-day shows tables available for sd_shows_lag')

    pool = getattr(args, 'yt_pool', None) or 'broadmatching'
    query = '''
    pragma SimpleColumns;
    pragma yt.Pool = "{pool}";

    select timeslot, min(Age) as minAge, SimDistance, SelectType
    from (
        select shows.*, shows.EventTime - banners.`timestamp` as Age
        from concat({sd_shows_tables_yql}) as shows
        inner join (
            select BannerID, min_timestamp as `timestamp`
            from (
                select pid, min(`timestamp`) as min_timestamp
                from `{banner_creation_time_slow}`
                group by pid
            ) as mints
            inner join `{banner_creation_time_slow}`
            using (pid)
            where BannerID != 0
        ) as banners
        using (BannerID)
    )
    group by 3*3600 * ((EventTime-3600) / (3*3600)) as timeslot, SimDistance, SelectType; -- -1 hour to get full 3 hours
    '''.format(
        pool=pool,
        sd_shows_tables_yql=','.join('`{}`'.format(x) for x in sd_shows_tables),
        banner_creation_time_slow=home.broadmatching.bmyt.direct_export.banners_creation_time_slow.read(),
    )
    request = run_yql_query(query, db=args.yt_cluster, title='monitoring: sd_shows_lag')
    if not request.table.fetch_full_data():
        raise RuntimeError("YQL fetch_full_data failed")

    for timeslot, min_age, sim_distance, select_type in request.table.rows:
        # fromtimestamp() without tz yields naive local time of the host.
        date = str(datetime.fromtimestamp(int(timeslot)))
        # Seconds -> hours. NOTE: `/` floors under Python 2 and yields a float
        # under Python 3; kept as in the original to not change the metric.
        lag = int(min_age) / 3600

        names = ['bm.bmyt_sd_show_lag.by_st_sd.{}.{}'.format(
            'ST=' + str(select_type), 'SD=' + str(sim_distance))]
        # TODO: remove once enough statistics have accumulated in by_st_sd.ST=5
        # NOTE(review): this comparison assumes select_type arrives as an int,
        # not a string - confirm against the YQL client's row format.
        if select_type == 5:
            names.append('bm.bmyt_sd_show_lag.by_sd.{}'.format('SD=' + str(sim_distance)))

        for name in names:
            write_monitoring_to_graphite(
                name, lag, date,
                graphite_frequency='one_hour',
                graphite_namespace='bs.sandbox_graphite_sender'
            )


def main():
    """Parse command-line arguments and run the requested monitoring tasks.

    Options:
        --yt-cluster: cluster name, defaults to HAHN_DEFAULT.
        --yt-pool: pool name, defaults to 'broadmatching'.
        --task: one or more task names to run, in the given order.

    Unknown task names are skipped with a warning instead of being silently
    ignored; running with no task at all is reported as well.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--yt-cluster', default=HAHN_DEFAULT)
    parser.add_argument('--yt-pool', default='broadmatching')
    parser.add_argument('--task', default=[], nargs='+')
    args = parser.parse_args()

    # Task name -> callable taking the parsed args.
    handlers = {
        'sd_shows_lag': sd_shows_lag,
    }

    if not args.task:
        logging.warning('No --task given, nothing to do')

    for task in args.task:
        handler = handlers.get(task)
        if handler is None:
            logging.warning(
                'Unknown task %r skipped; known tasks: %s',
                task, ', '.join(sorted(handlers)))
            continue
        handler(args)


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
