#!/usr/bin/python
import simplejson
import os
import datetime
from yt.wrapper import YtClient
from qb2.api.v1.typing import (
    Optional,
    String, Dict, Yson,
    List
)
from nile.api.v1 import (
    clusters,
    extractors as ne,
    aggregators as na,
    filters as nf
)
from qb2.api.v1 import (
    filters as qf,
    extractors as qe
)

# Shared YT client and Nile cluster handles; both are assigned in main().
yt_client = None
cluster = None

# Granularity flag handed to find_timestamps_to_process() and process_timestamp().
PROCESS_BY_HOUR = False
# Root directory for all output tables; can be overridden via the environment.
OUTPUT_HOME_DIR = os.environ.get('OUTPUT_HOME_DIR', '//home/mpfs-stat/notifier/rtx')

# Log names, resolved to YT directories by get_log_home_path().
NOTIFIER_EVENTS_LOG = 'ydisk-notifier-events-log'
METRIKA_MOBILE_LOG = 'metrika-mobile-log'

# OAuth client ids accepted by the client_id filter in prepare_oauth_logs().
DISK_IOS_CLIENT_ID = 'f6ba541f65084e688107879721631a65'
DISK_ANDROID_CLIENT_ID = 'ff90127313fd4378873d6b57914e8e11'

# How many previous daily tables to include when joining oauth / notifier logs.
OAUTH_JOIN_DAYS = int(os.environ.get('OAUTH_JOIN_DAYS', '1'))
NOTIFIER_JOIN_DAYS = int(os.environ.get('NOTIFIER_JOIN_DAYS', '7'))

# Type hints attached to computed columns in prepare_notifier_events_log().
_DICT_STR_ANY = Optional[Dict[String, Yson]]
_LIST_OF_STR = Optional[List[String]]


def log(message):
    """Write ``message`` to standard output followed by a newline."""
    # The parenthesised form is a valid print statement on Python 2 and a
    # valid function call on Python 3, and prints the same text on both.
    print(message)


def parse_datetime_from_path(path):
    """Parse the last component of a '/'-separated path into a datetime.

    Supported forms of the last component:
      * ``YYYY-MM-DDTHH:MM:SS`` -- full timestamp;
      * a string of digits -- unix timestamp, converted with
        ``datetime.fromtimestamp`` (local time of the running process);
      * ``YYYY-MM-DD`` -- calendar date, returned as midnight of that day.

    :param path: table path or a bare date string.
    :return: ``datetime.datetime`` instance.
    :raises ValueError: if the last component matches none of the forms.
    """
    date_str = path.split('/')[-1]
    # Inspect only the last component: a 'T' or ':' elsewhere in the path must
    # not influence which format is chosen.
    if 'T' in date_str:
        return datetime.datetime.strptime(date_str, '%Y-%m-%dT%H:%M:%S')
    if date_str.isdigit():
        return datetime.datetime.fromtimestamp(int(date_str))
    return datetime.datetime.strptime(date_str, '%Y-%m-%d')


def format_date_for_yt(dt):
    """Render ``dt`` as a ``YYYY-MM-DD`` string (the daily table name form)."""
    day_format = '%Y-%m-%d'
    return dt.strftime(day_format)


def format_datetime_for_yt(dt):
    """Render ``dt`` as a ``YYYY-MM-DDTHH:MM:SS`` string."""
    timestamp_format = '%Y-%m-%dT%H:%M:%S'
    return dt.strftime(timestamp_format)


def nile_path_for_list(paths):
    """Combine several table paths into one ``{a,b,c}`` brace expression."""
    comma_separated = ",".join(paths)
    return "".join(("{", comma_separated, "}"))


def extract_win_type(metrika_event_name):
    """Map a metrika push event name to 'show', 'click' or 'unknown'."""
    show_events = ('bright_push_received', 'push_showed')
    click_events = ('bright_push_tapped', 'push_tapped')

    if metrika_event_name in show_events:
        return 'show'
    if metrika_event_name in click_events:
        return 'click'
    return 'unknown'


def extract_platform_from_tags(xiva_tags_str):
    """Derive the target platform from the xiva tags value.

    :param xiva_tags_str: tags value of a notifier record; may be ``None`` or
        empty when the record carries no tags.
    :return: 'Android' if the value contains 'android', 'iOS' if it contains
        'ios_bright', otherwise 'unknown'.
    """
    # A missing value used to raise TypeError on the ``in`` test; treat it as
    # an unrecognised platform, like the other parsers treat falsy input.
    if not xiva_tags_str:
        return 'unknown'
    if 'android' in xiva_tags_str:
        return 'Android'
    elif 'ios_bright' in xiva_tags_str:
        return 'iOS'
    else:
        return 'unknown'


def parse_user_features(user_features_str):
    """Decode the JSON-encoded user features string; falsy input gives None.

    Backslash-escaped double quotes are unescaped before the JSON is parsed.
    """
    if user_features_str:
        unescaped = user_features_str.replace("\\\"", "\"")
        return simplejson.loads(unescaped)
    return None


def parse_alternatives(alternatives_str):
    """Split a comma-separated alternatives string; falsy input gives None."""
    return alternatives_str.split(",") if alternatives_str else None


def prepare_notifier_events_log(job, path):
    """Build the stream of sent mobile pushes from the notifier events log.

    Reads ``path`` with a weak schema, keeps rows whose ``lifecycle_stage`` is
    'push_sent' and whose ``channel`` is one of the two mobile v2 channels,
    and projects the columns used by the later joins.

    :param job: job object the stream is attached to.
    :param path: table path (or brace expression of paths) to read.
    :return: the projected stream.
    """
    raw_schema = {
        'logtime': str,
        'unixtime': int,
        'uid': int,

        'group_key': str,
        'record_type': str,
        'template': str,
        'metadata': str,

        'xiva_transit_id': str,
        'xiva_tags': str,

        'channel': str,
        'lifecycle_stage': str,

        'template_reason_source': str,
        'template_reason_alternatives': str,
        'template_reason_rtx_reqid': str,
        'template_reason_reqid': str,
        'template_reason_user_features': str,
    }

    mobile_channels = ('notification_mobile_v2', 'notification_mobile_v2_bright')
    is_sent_mobile_push = qf.and_(
        qf.equals('lifecycle_stage', 'push_sent'),
        qf.one_of('channel', mobile_channels),
    )

    events = job.table(path, weak_schema=raw_schema)
    sent_pushes = events.filter(is_sent_mobile_push)

    # The lambdas keep the source column names as argument names on purpose:
    # they are left exactly as in the original extractor definitions.
    return sent_pushes.project(
        'logtime',
        'uid',
        'group_key',
        'record_type',
        'template',
        'metadata',
        'xiva_transit_id',
        'xiva_tags',
        passport_uid='uid',
        timestamp='unixtime',
        reqid='template_reason_reqid',
        rtx_reqid='template_reason_rtx_reqid',
        source='template_reason_source',
        alternatives=ne.custom(
            lambda template_reason_alternatives: parse_alternatives(template_reason_alternatives)
        ).add_hints(type=_LIST_OF_STR),
        user_features=ne.custom(
            lambda template_reason_user_features: parse_user_features(template_reason_user_features)
        ).add_hints(type=_DICT_STR_ANY),
        group_and_record_type=ne.custom(
            lambda group_key, record_type: (group_key + "_" + record_type) if group_key and record_type else None
        ).add_hints(type=str),
        target_platform=ne.custom(
            lambda xiva_tags: extract_platform_from_tags(xiva_tags)
        ).add_hints(type=str),
        project=ne.const('disk'),
    )


def prepare_metrika_events_log(job, path):
    """Build the stream of mobile push events from the metrika mobile log.

    Parses ``path`` with the 'metrika-mobile-log' qb2 parser, keeps the four
    push event names for api key "18895" whose raw value mentions
    'notification_mobile_v2', and renames ``event_name`` to
    ``metrika_event_name``.

    :param job: job object the stream is attached to.
    :param path: table path (or brace expression of paths) to read.
    :return: the projected stream.
    """
    push_event_names = [
        'bright_push_received',
        'bright_push_tapped',
        'push_tapped',
        'push_showed'
    ]

    extracted_fields = [
        'device_id',
        'app_platform',

        'event_name',
        'raw_event_value',
        'event_timestamp_msk',
        # Hidden helper field: the parsed JSON is used only by the two extractors below.
        qe.custom('event_value_json', lambda raw_event_value: simplejson.loads(raw_event_value)).hide(),
        qe.custom('group_and_record_type', lambda event_value_json: event_value_json.get('notification_mobile_v2', None)).add_hints(type=str),
        qe.custom('transit_id', lambda event_value_json: event_value_json.get('transit_id', None)).rename('xiva_transit_id').add_hints(type=str),
    ]

    row_filters = [
        qf.equals("api_key_str", "18895"),
        qf.one_of('event_name', push_event_names),
        qf.contains('raw_event_value', 'notification_mobile_v2')
    ]

    parsed_events = job.table(path).qb2(
        log='metrika-mobile-log',
        fields=extracted_fields,
        filters=row_filters
    )
    return parsed_events.project(ne.all(exclude=['event_name']), metrika_event_name='event_name')


def prepare_oauth_logs(job, path):
    """Build a device_id -> uid mapping stream from oauth logs.

    Keeps successful ('OK') 'verify_token' records of the Disk iOS/Android
    clients, groups them by ``device_id`` and aggregates ``uid`` with
    ``na.last`` ordered by ``unixtime``.

    :param job: job object the stream is attached to.
    :param path: table path (or brace expression of paths) to read.
    :return: stream with one aggregated ``uid`` per ``device_id``.
    """
    disk_client_ids = [DISK_IOS_CLIENT_ID, DISK_ANDROID_CLIENT_ID]
    is_verified_disk_token = qf.and_(
        qf.equals('mode', 'verify_token'),
        qf.equals('status', 'OK'),
        qf.one_of('client_id', disk_client_ids)
    )

    oauth_records = job.table(path).project('mode', 'status', 'client_id', 'uid', 'device_id', 'unixtime')
    verified = oauth_records.filter(is_verified_disk_token)
    per_device = verified.project('uid', 'device_id', 'unixtime').groupby('device_id')
    return per_device.aggregate(
        uid=na.last('uid', by='unixtime'),
    )


def prepare_oauth_logs_to_join(job, dt):
    """Build the device_id -> uid stream from the oauth logs around ``dt``.

    The set of daily oauth tables is chosen by get_oauth_logs_paths() using
    OAUTH_JOIN_DAYS.
    """
    oauth_paths = get_oauth_logs_paths(OAUTH_JOIN_DAYS, dt)
    combined_path = nile_path_for_list(oauth_paths)
    return prepare_oauth_logs(job, combined_path)


def process_timestamp(dt, by_hour):
    """Build and run one job producing the push 'sends' and 'wins' tables for ``dt``.

    Metrika push events are attributed to a uid through the oauth device
    mapping and then joined with notifier 'push_sent' records, either by
    transit id or, when the event has none, by uid and push type.

    Tables written under output_path_for_datetime(dt, <name>, by_hour):
    'parsed_metrika_events', 'parsed_metrika_events_with_transit_id',
    'parsed_metrika_events_without_transit_id', 'joined_by_transit_id',
    'joined_by_uid_and_type', 'sends' and 'wins'.

    :param dt: start of the interval to process.
    :param by_hour: selects hourly (true) or daily (false) source and output paths.
    """
    log('Processing ' + str(dt) + ' by_hour=' + str(by_hour))
    job = cluster.job()

    # Notifier pushes of the processed interval itself; they become the 'sends' table.
    source_notifier_parths = source_notifier_logs(dt, by_hour)
    notifier_pushes_log_current = prepare_notifier_events_log(job, nile_path_for_list(source_notifier_parths))
    log('Current notifier events log: ' + str(source_notifier_parths))

    # Wider set of notifier logs (see notifier_logs_to_join) used as the join side for events.
    notifier_log_paths_to_join = notifier_logs_to_join(dt, by_hour)
    log('Notifier logs to join: ' + ", ".join(notifier_log_paths_to_join))
    notifier_pushes_log = prepare_notifier_events_log(job, nile_path_for_list(notifier_log_paths_to_join))

    source_metrika_paths = source_metrika_logs(dt, by_hour)
    log('Current metrika logs: ' + ", ".join(source_metrika_paths))
    all_mobile_push_events = prepare_metrika_events_log(job, nile_path_for_list(source_metrika_paths))

    # Attach a uid to every event; the inner join drops events whose device_id
    # is absent from the oauth mapping.
    uid_by_device_id = prepare_oauth_logs_to_join(job, dt)
    all_mobile_push_events = all_mobile_push_events.join(uid_by_device_id, by='device_id', type='inner')

    # Split events on whether the raw event value mentions 'transit_id'.
    with_transit_id = all_mobile_push_events.filter(
        qf.contains('raw_event_value', 'transit_id')
    )

    # The xiva_transit_id column is removed here so it does not clash with the
    # notifier column of the same name after the join below (presumably -- confirm).
    without_transit_id = all_mobile_push_events.filter(
        qf.not_(qf.contains('raw_event_value', 'transit_id'))
    ).project(ne.all(exclude=['xiva_transit_id']))

    def parse_ts(date_str):
        # Converts a unix timestamp (anything int() accepts) using the local time zone.
        return datetime.datetime.fromtimestamp(int(date_str))

    with_transit_id_joined = with_transit_id.join(notifier_pushes_log, by=('xiva_transit_id', 'uid', 'group_and_record_type'), type='inner')
    # Without a transit id the match is only by uid and push type, so the filter
    # keeps pairs where the send ('timestamp', the notifier unixtime) and the event
    # fall on the same calendar date and the send precedes the event.
    # NOTE(review): int() is applied to 'timestamp' but not to event_timestamp_msk
    # in the comparison -- presumably the latter is already numeric; confirm.
    # NOTE(review): the '_msk' suffix suggests a shifted time base -- confirm it is
    # comparable with the notifier unixtime.
    without_transit_id_joined = without_transit_id.\
        join(notifier_pushes_log, by=('uid', 'group_and_record_type'), type='inner') \
        .filter(nf.custom(lambda event_timestamp_msk, timestamp: parse_ts(timestamp).date() == parse_ts(event_timestamp_msk).date()
                                                                 and int(timestamp) < event_timestamp_msk))

    rtx_wins_log = job\
        .concat(with_transit_id_joined, without_transit_id_joined) \
        .project(
            'reqid',
            'rtx_reqid',
            timestamp='event_timestamp_msk',
            device_id='device_id',
            platform='app_platform',
            project=ne.const('disk'),
            win_type=ne.custom(lambda metrika_event_name: extract_win_type(metrika_event_name)).add_hints(type=str),
        )

    # Intermediate tables are stored next to the final ones.
    all_mobile_push_events.put(output_path_for_datetime(dt, 'parsed_metrika_events', by_hour))
    with_transit_id.put(output_path_for_datetime(dt, 'parsed_metrika_events_with_transit_id', by_hour))
    without_transit_id.put(output_path_for_datetime(dt, 'parsed_metrika_events_without_transit_id', by_hour))

    with_transit_id_joined.put(output_path_for_datetime(dt, 'joined_by_transit_id', by_hour))
    without_transit_id_joined.put(output_path_for_datetime(dt, 'joined_by_uid_and_type', by_hour))
    # 'sends' and 'wins' are the tables find_timestamps_to_process() checks for.
    notifier_pushes_log_current.sort('timestamp').put(output_path_for_datetime(dt, 'sends', by_hour))
    rtx_wins_log.put(output_path_for_datetime(dt, 'wins', by_hour))

    job.run()
    # print job.flow_graph


def find_timestamps_to_process(by_hour):
    """Return, in ascending order, the interval starts that still need processing.

    Ten candidates are examined, starting from the current hour (or from
    midnight today when ``by_hour`` is false) and going back one day at a
    time. A candidate is returned when it is not earlier than 2019-01-30, its
    'wins' or 'sends' output table is missing, and all its source logs exist.

    :param by_hour: selects hourly or daily output/source paths.
    :return: sorted list of ``datetime.datetime`` interval starts.
    """
    earliest_day = datetime.date(year=2019, month=1, day=30)

    newest_start = datetime.datetime.now().replace(second=0, minute=0, microsecond=0)
    if not by_hour:
        newest_start = newest_start.replace(hour=0)

    pending = []
    # NOTE(review): candidates are spaced one day apart even when by_hour is
    # true, so only one hour per day is examined -- confirm this is intended.
    for days_back in range(10):
        candidate = newest_start - datetime.timedelta(days=days_back)
        if candidate.date() < earliest_day:
            continue

        wins_path = output_path_for_datetime(candidate, "wins", by_hour)
        sends_path = output_path_for_datetime(candidate, "sends", by_hour)
        if yt_client.exists(wins_path) and yt_client.exists(sends_path):
            # Both final tables are present: nothing to do for this interval.
            continue

        if has_source_logs(candidate, by_hour):
            pending.append(candidate)

    pending.sort()
    return pending


def has_source_logs(dt, by_hour):
    """Tell whether every source log needed to process ``dt`` exists.

    Checks the notifier and metrika source paths for the interval, in that
    order, and stops at the first missing one after logging it.

    :param dt: start of the interval.
    :param by_hour: selects hourly or daily source paths.
    :return: True if all source tables exist, False otherwise.
    """
    source_logs_paths = source_notifier_logs(dt, by_hour) + source_metrika_logs(dt, by_hour)

    for path in source_logs_paths:
        if not yt_client.exists(path):
            # Message previously read "No no source log ..." (duplicated word).
            log('No source log ' + path + '. Skip processing of ' + str(dt))
            return False

    return True


def source_notifier_logs(dt, by_hour):
    """Return the single notifier log path (hourly or daily) for ``dt`` as a list."""
    build_path = get_1h_log_path if by_hour else get_1d_log_path
    return [build_path(NOTIFIER_EVENTS_LOG, dt)]


def notifier_logs_to_join(dt, by_hour):
    """Return the notifier log paths that events of ``dt`` are joined against.

    Always includes the daily logs chosen by get_prev_day_logs() with
    NOTIFIER_JOIN_DAYS; in hourly mode the intraday logs of the same day are
    put in front of them.
    """
    if not by_hour:
        return get_prev_day_logs(NOTIFIER_EVENTS_LOG, NOTIFIER_JOIN_DAYS, dt)

    same_day_paths = get_all_same_day_logs(NOTIFIER_EVENTS_LOG, dt)
    daily_paths = get_prev_day_logs(NOTIFIER_EVENTS_LOG, NOTIFIER_JOIN_DAYS, dt)
    return same_day_paths + daily_paths


def source_metrika_logs(dt, by_hour):
    """Return the metrika log paths covering the interval that starts at ``dt``.

    Hourly mode yields the two 30-minute tables starting at ``dt`` and at
    ``dt`` + 30 minutes; daily mode yields the single daily table.
    """
    if not by_hour:
        return [get_1d_log_path(METRIKA_MOBILE_LOG, dt)]

    half_hour = datetime.timedelta(minutes=30)
    return [get_30min_log_path(METRIKA_MOBILE_LOG, start) for start in (dt, dt + half_hour)]


def get_1d_log_path(log_name, dt):
    """Return the path of the daily ('1d') table of ``log_name`` for ``dt``."""
    log_home = get_log_home_path(log_name)
    table_name = format_date_for_yt(dt)
    return ''.join((log_home, '/1d/', table_name))


def get_1h_log_path(log_name, dt):
    """Return the path of the hourly ('1h') table of ``log_name`` for ``dt``."""
    log_home = get_log_home_path(log_name)
    table_name = format_datetime_for_yt(dt)
    return ''.join((log_home, '/1h/', table_name))


def get_30min_log_path(log_name, dt):
    """Return the path of the 30-minute table of ``log_name`` for ``dt``."""
    log_home = get_log_home_path(log_name)
    table_name = format_datetime_for_yt(dt)
    return ''.join((log_home, '/30min/', table_name))


def get_all_same_day_logs(log_name, dt):
    """List the intraday tables of ``log_name`` that belong to the day of ``dt``.

    The '30min' directory is preferred; '1h' is used when '30min' does not
    exist.

    :raises ValueError: if the log has neither a '30min' nor a '1h' directory.
    """
    log_home = get_log_home_path(log_name)

    for granularity in ('/30min', '/1h'):
        if yt_client.exists(log_home + granularity):
            intraday_dir = log_home + granularity
            break
    else:
        raise ValueError('log has no inday aggregation: ' + log_name)

    # Table names encode their start time; keep those on the same calendar date.
    return [
        table_path
        for table_path in yt_client.list(intraday_dir, absolute=True)
        if parse_datetime_from_path(table_path).date() == dt.date()
    ]


def get_oauth_logs_paths(days_count, current_dt):
    """Return the daily oauth log paths to read for ``current_dt``.

    The table of the current day is included only if it already exists; the
    ``days_count`` preceding days are included without an existence check,
    newest first.
    """
    def oauth_table_for(day):
        return '//statbox/oauth-log/' + format_date_for_yt(day)

    today = current_dt.date()
    paths = []

    today_path = oauth_table_for(today)
    if yt_client.exists(today_path):
        paths.append(today_path)

    paths.extend(
        oauth_table_for(today - datetime.timedelta(days=days_back))
        for days_back in range(1, days_count + 1)
    )
    return paths


def get_prev_day_logs(log_name, days_count, current_dt, add_current=True):
    """Return daily table paths of ``log_name`` for the days before ``current_dt``.

    When ``add_current`` is true and the daily table of the current day
    already exists, it is placed first. The ``days_count`` preceding days
    follow, newest first, without an existence check.
    """
    today = current_dt.date()
    paths = []

    # The current day is usable only once its daily table has been built.
    if add_current:
        today_path = get_1d_log_path(log_name, today)
        if yt_client.exists(today_path):
            paths.append(today_path)

    paths.extend(
        get_1d_log_path(log_name, today - datetime.timedelta(days=days_back))
        for days_back in range(1, days_count + 1)
    )
    return paths


def get_log_home_path(log_name):
    """Return the logfeller directory that holds the tables of ``log_name``."""
    logs_root = '//home/logfeller/logs/'
    return logs_root + log_name


def output_path_for_datetime(dt, table_name, by_hour):
    """Return the output path of ``table_name`` for the interval starting at ``dt``.

    The interval directory under OUTPUT_HOME_DIR is named with the full
    timestamp in hourly mode and with the date only in daily mode.
    """
    render_interval = format_datetime_for_yt if by_hour else format_date_for_yt
    return "/".join([OUTPUT_HOME_DIR, render_interval(dt), table_name])


def main():
    """Script entry point: set up the YT/Nile handles and process intervals.

    Tokens are read from the YT_TOKEN and YQL_TOKEN environment variables.
    If PROCESS_DATETIME is set, its comma-separated values are parsed and
    processed; otherwise the intervals are discovered with
    find_timestamps_to_process().
    """
    global yt_client, cluster

    yt_token = os.environ['YT_TOKEN']
    yql_token = os.environ['YQL_TOKEN']
    yt_client = YtClient(proxy="hahn", config={"token": yt_token})
    cluster = clusters.Hahn(token=yt_token, yql_token=yql_token)

    requested = os.environ.get('PROCESS_DATETIME')
    if requested:
        timestamps_to_process = [parse_datetime_from_path(item) for item in requested.split(",")]
    else:
        timestamps_to_process = find_timestamps_to_process(PROCESS_BY_HOUR)

    for interval_start in timestamps_to_process:
        process_timestamp(interval_start, PROCESS_BY_HOUR)


# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
