#-*- coding: UTF-8 -*-
from common import *

class get_uids_by_uuid(object):
    """Mapper that emits the uid of every record carrying a given uuid."""

    def __init__(self, uuid):
        # Value compared against the "uuid" field of each incoming record.
        self.uuid = uuid

    def __call__(self, recs):
        """Yield ``Record(uid=...)`` for each record whose "uuid" equals ``self.uuid``."""
        wanted = self.uuid
        # Lazy filter: records are still consumed one at a time.
        matching = (row for row in recs if row["uuid"] == wanted)
        for row in matching:
            yield Record(uid=row["uid"])

def has_interest_by_search_stat(rec, fetch_params):
    """Decide from search statistics whether a user shows interest.

    ``rec["web_requests_stats"]`` and ``rec["video_requests_stats"]`` are
    read as mappings ``object_id -> {"requests": [...]}`` whose request
    items are UTF-8 byte strings (they are decoded here).

    Interest is reported when either
      * the number of requests attached to ``fetch_params["object_ids"]``
        reaches ``fetch_params["object_ids_threshold"]``, or
      * the number of requests containing one of ``fetch_params["queries"]``
        reaches ``fetch_params["queries_threshold"]`` (a request matching
        several queries is counted once per matching query).

    Returns a tuple ``(has_interest, source, has_black_list_interest)``:
    ``source`` is ``"search"`` when interest is found and ``None`` otherwise;
    ``has_black_list_interest`` is ``True`` when any request contains one of
    ``fetch_params["queries_blacklist"]``.
    """
    web_stats = rec["web_requests_stats"]
    if web_stats is None:
        # Missing web stats short-circuit the whole check, video included.
        return False, None, False

    video_stats = rec["video_requests_stats"]
    if video_stats is None:
        # Web stats may be present without video stats; treat the latter
        # as empty instead of failing on iteration.
        video_stats = {}
    all_stats = (web_stats, video_stats)

    # Requests attached to the objects of interest.
    objects = 0
    for object_id in fetch_params["object_ids"]:
        for stats in all_stats:
            if object_id in stats:
                objects += len(stats[object_id]["requests"])

    # Normalise every request once; the counter is reused for both the
    # positive queries and the blacklist.
    queries = Counter()
    for stats in all_stats:
        for object_id in stats:
            for elem in stats[object_id]["requests"]:
                queries[elem.decode('utf8').lower()] += 1

    reqs_about_event = 0
    for query in fetch_params["queries"]:
        for text, frequency in queries.items():
            if query in text:
                reqs_about_event += frequency

    has_black_list_interest = any(
        query in text
        for query in fetch_params["queries_blacklist"]
        for text in queries
    )

    enough_objects = objects >= fetch_params["object_ids_threshold"]
    enough_queries = reqs_about_event >= fetch_params["queries_threshold"]
    has_interest = enough_objects or enough_queries
    source = "search" if has_interest else None

    return has_interest, source, has_black_list_interest

def has_interest_by_tv_online_stats(rec, fetch_params):
    """Decide from online-TV statistics whether a user shows interest.

    ``rec["tv_online_stats"]`` is read as a mapping
    ``object_id -> {"computed_program", "computed_channel", "tvt"}``.
    ``computed_program`` is a UTF-8 byte string (it is decoded here).

    The ``tvt`` of an entry is added to the programme total once per
    matching item of ``fetch_params["tv_online_programs"]`` and to the
    channel total once per matching item of
    ``fetch_params["tv_online_channels"]``. Programme matching is a
    case-insensitive substring test; channel matching is a plain
    substring test.

    Returns a tuple ``(has_interest, source, has_black_list_interest)``:
    interest is reported (with ``source == "tv_online"``) when either total
    reaches its ``*_tvt_threshold``; the blacklist flag is set when any
    programme name contains an item of
    ``fetch_params["tv_online_programs_black_list"]``.
    """
    tv_stats = rec["tv_online_stats"]
    if tv_stats is None:
        return False, None, False

    # The programme name is lowercased before matching, so the needles must
    # be lowercased as well or mixed-case needles could never match.
    programs = [program.lower() for program in fetch_params["tv_online_programs"]]
    black_list = [program.lower() for program in fetch_params["tv_online_programs_black_list"]]
    channels = fetch_params["tv_online_channels"]

    programs_tvt = 0
    channels_tvt = 0
    has_black_list_interest = False
    for object_id in tv_stats:
        stat = tv_stats[object_id]

        # Decode only when there is something to match against, so entries
        # are not touched when no programme needles were requested.
        if programs or black_list:
            program_name = stat["computed_program"].decode('utf8').lower()
            for program in programs:
                if program in program_name:
                    programs_tvt += stat["tvt"]
            if any(program in program_name for program in black_list):
                has_black_list_interest = True

        if channels:
            channel_name = stat["computed_channel"]
            for channel in channels:
                if channel in channel_name:
                    channels_tvt += stat["tvt"]

    enough_programs = programs_tvt >= fetch_params["tv_online_programs_tvt_threshold"]
    enough_channels = channels_tvt >= fetch_params["tv_online_channels_tvt_threshold"]
    has_interest = enough_programs or enough_channels
    source = "tv_online" if has_interest else None

    return has_interest, source, has_black_list_interest

def has_interest_by_browser_stats(rec, fetch_params):
    """Decide from browsing statistics whether a user shows interest.

    ``rec["urls"]`` is read as a mapping ``visited_url -> count``. The count
    of every visited URL containing one of ``fetch_params["urls"]`` as a
    substring is summed (once per matching pattern) and compared with
    ``fetch_params["urls_threshold"]``.

    Returns a tuple ``(has_interest, source, has_black_list_interest)``:
    ``source`` is ``"browser"`` when the threshold is reached and ``None``
    otherwise. Title-based matching (``fetch_params["titles"]`` /
    ``"titles_blacklist"``) is not implemented here, so the blacklist flag
    is always ``False``.
    """
    wanted_urls = fetch_params["urls"]

    # The record is only consulted when there is something to look for.
    url_counts = rec["urls"] if wanted_urls else None
    if url_counts is None:
        # A record without browsing data contributes nothing.
        url_counts = {}

    urls = 0
    for url in wanted_urls:
        for visited in url_counts:
            try:
                if url in visited:
                    urls += url_counts[visited]
            except (TypeError, ValueError):
                # Best effort: skip entries that cannot be compared or
                # counted (e.g. None keys/values, undecodable text).
                continue

    has_interest = urls >= fetch_params["urls_threshold"]
    source = "browser" if has_interest else None

    return has_interest, source, False

class get_uids_for_push(object):
    """Mapper that keeps records showing search, TV-online or browser interest."""

    def __init__(self, fetch_params):
        # Matching parameters forwarded to every has_interest_by_* check.
        self.fetch_params = fetch_params

    def __call__(self, recs):
        """Yield ``Record(rec, source=...)`` for every interested record.

        A record is dropped when any of the checks reports blacklist
        interest. Otherwise the first positive check, in the order
        search, TV online, browser, supplies the ``source``.
        """
        params = self.fetch_params
        for rec in recs:
            # All three checks are always evaluated, in this priority order.
            verdicts = (
                has_interest_by_search_stat(rec, params),
                has_interest_by_tv_online_stats(rec, params),
                has_interest_by_browser_stats(rec, params),
            )

            if any(blacklisted for _, _, blacklisted in verdicts):
                continue

            for interested, source, _ in verdicts:
                if interested:
                    yield Record(rec, source=source)
                    break

def has_intercepting_segments(user_segments, segments_of_interest):
    """Return True when the user has at least one of the wanted segments.

    Each item of ``segments_of_interest`` is looked up in ``user_segments``
    both as given and, when it can be converted, as an ``int``. Segment ids
    that are not numeric are only compared as given.

    Returns ``False`` when either argument is empty or ``None``, or when
    nothing matches.
    """
    if not segments_of_interest or not user_segments:
        return False

    for segment in segments_of_interest:
        if segment in user_segments:
            return True
        try:
            numeric = int(segment)
        except (TypeError, ValueError):
            # Non-numeric id: there is no integer form to compare.
            continue
        if numeric in user_segments:
            return True

    return False

class get_crypta_uids_for_push(object):
    """Mapper that keeps profiles sharing a segment with the requested ones."""

    # Segment fields, checked in this order. When several fields match,
    # the last matching one is reported as the source.
    _SEGMENT_FIELDS = (
        "audience_segments",
        "heuristic_internal",
        "marketing_segments",
        "heuristic_segments",
        "heuristic_common",
        "longterm_interests",
        "lal_internal",
        "lal_common",
    )

    def __init__(self, fetch_params):
        # Holds, per segment field, the segments of interest.
        self.fetch_params = fetch_params

    def __call__(self, recs):
        """Yield ``Record(uid=str(yandexuid), source=field)`` for matching profiles."""
        for rec in recs:
            source = None
            # Every field is inspected; a later match overrides an earlier one.
            for field in self._SEGMENT_FIELDS:
                if has_intercepting_segments(rec[field], self.fetch_params[field]):
                    source = field

            if source is not None:
                yield Record(uid=str(rec["yandexuid"]), source=source)

def main():
    """Parse command-line options and run the cluster jobs that build the push audience.

    Steps, in order:
      1. Parse the matching options and collect them into ``fetch_params``.
      2. If any segment option is given, run a first job that maps
         CRYPTA_PROFILES through ``get_crypta_uids_for_push`` into
         ``<output_table>_crypta_uids`` and pass that table to
         ``prepare_uids_to_push`` (defined outside this file).
      3. Run a second job that writes ``<output_table>``: either the
         ``install_id`` column of the ``--direct_uids`` table, or the
         distinct ``install_id`` values of the records selected by
         ``get_uids_for_push`` (plus the prepared segment-based records).
    """
    parser = argparse.ArgumentParser()
    # Search / TV-online / browser matching options. Every threshold defaults
    # to 1e10, presumably so that a signal is inactive unless its threshold
    # is passed explicitly.
    parser.add_argument('--object_ids', nargs='+', default=[])
    parser.add_argument('--object_ids_threshold', type=int, default=1e10)
    parser.add_argument('--queries', nargs='+', default=[])
    parser.add_argument('--queries_threshold', type=int, default=1e10)
    parser.add_argument('--queries_blacklist', nargs='+', default=[])
    parser.add_argument('--tv_online_programs', nargs='+', default=[])
    parser.add_argument('--tv_online_programs_tvt_threshold', type=int, default=1e10)
    parser.add_argument('--tv_online_programs_black_list', nargs='+', default=[])
    parser.add_argument('--tv_online_channels', nargs='+', default=[])
    parser.add_argument('--tv_online_channels_tvt_threshold', type=int, default=1e10)
    parser.add_argument('--urls', nargs='+', default=[])
    parser.add_argument('--urls_threshold', type=int, default=1e10)
    parser.add_argument('--titles', nargs='+', default=[])
    parser.add_argument('--titles_threshold', type=int, default=1e10)
    parser.add_argument('--titles_blacklist', nargs='+', default=[])

    # Segment options; giving any of them enables the profile-based job below.
    parser.add_argument('--audience_segments', nargs='+', default=[])
    parser.add_argument('--heuristic_internal', nargs='+', default=[])
    parser.add_argument('--marketing_segments', nargs='+', default=[])
    parser.add_argument('--heuristic_segments', nargs='+', default=[])
    parser.add_argument('--heuristic_common', nargs='+', default=[])
    parser.add_argument('--longterm_interests', nargs='+', default=[])
    parser.add_argument('--lal_internal', nargs='+', default=[])
    parser.add_argument('--lal_common', nargs='+', default=[])

    # Path of a table that already holds install_id values; when set, the
    # interest-based selection is skipped entirely.
    parser.add_argument('--direct_uids', type=str, default="")

    # uuid used for the left_only join against VIEWED_CONTENT_TABLE below.
    parser.add_argument('--uuid', type=str, default="")
    parser.add_argument('--output_table', type=str, required=True)
    args = parser.parse_args()

    # Queries are decoded from UTF-8 and lowercased here.
    # NOTE(review): TV programme names are decoded but not lowercased, unlike
    # queries - confirm that the matcher handles letter case.
    fetch_params = { "object_ids" : args.object_ids,
                     "object_ids_threshold" : args.object_ids_threshold,
                     "queries" : [query.decode('utf8').lower() for query in args.queries],
                     "queries_threshold" : args.queries_threshold,
                     "queries_blacklist" : [query.decode('utf8').lower() for query in args.queries_blacklist],
                     "tv_online_programs" : [program.decode('utf8') for program in args.tv_online_programs],
                     "tv_online_programs_tvt_threshold" : args.tv_online_programs_tvt_threshold,
                     "tv_online_programs_black_list" : [program.decode('utf8') for program in args.tv_online_programs_black_list],
                     "tv_online_channels" : args.tv_online_channels,
                     "tv_online_channels_tvt_threshold" : args.tv_online_channels_tvt_threshold,
                     "urls" : args.urls,
                     "urls_threshold" : args.urls_threshold,
                     "titles" : args.titles,
                     "titles_threshold" : args.titles_threshold,
                     "titles_blacklist" : args.titles_blacklist,
                     "audience_segments" : args.audience_segments,
                     "heuristic_internal" : args.heuristic_internal,
                     "marketing_segments" : args.marketing_segments,
                     "heuristic_segments" : args.heuristic_segments,
                     "lal_internal" : args.lal_internal,
                     "lal_common" : args.lal_common,
                     "heuristic_common" : args.heuristic_common,
                     "longterm_interests" : args.longterm_interests }

    # Cluster environment shared by both jobs.
    cluster = clusters.yt.Hahn().env(parallel_operations_limit=10,
                                     yt_spec_defaults=dict(
                                         pool_trees=["physical"],
                                         tentative_pool_trees=["cloud"]
                                     ),
                                     templates=dict(
                                         tmp_root='//tmp',
                                         title='GetUidsForPushes'
                                     ))

    # The profile-based selection is needed as soon as one segment list is non-empty.
    need_crypta_push = False
    if args.audience_segments or args.heuristic_internal or args.marketing_segments or args.heuristic_segments\
            or args.lal_internal or args.lal_common or args.heuristic_common or args.longterm_interests:
        need_crypta_push = True

    if need_crypta_push:
        # First job: select profiles by segment and store them sorted by uid.
        job = cluster.job()
        job.table(CRYPTA_PROFILES).map(get_crypta_uids_for_push(fetch_params)) \
           .sort('uid') \
           .put(args.output_table + "_crypta_uids")
        job.run()

        # Helper from outside this file; it is expected to produce the
        # "_crypta_uids_prepared_for_push" table that the second job reads.
        prepare_uids_to_push(cluster, args.output_table + "_crypta_uids", args.output_table + "_crypta_uids_prepared_for_push", [], False)

    # Second job: produce the final list of install_id values.
    job = cluster.job()

    if args.direct_uids:
        # Recipients are given explicitly: copy their install_id column, sorted.
        job.table(args.direct_uids).project('install_id') \
                                   .sort('install_id') \
                                   .put(args.output_table)
    else:
        # Select users by interest among records with in_sup_base defined and True.
        uids_with_info = job.table(STATS_PREFIX + AGGREGATED_STATS_SUFFIX) \
                            .filter(sf.defined('in_sup_base'),
                                    sf.equals('in_sup_base', True)) \
                            .map(get_uids_for_push(fetch_params))
        if args.uuid:
            # left_only join on uid against the uids found for this uuid -
            # presumably this drops users already present in
            # VIEWED_CONTENT_TABLE for that uuid; confirm the join semantics.
            uids_viewed = job.table(VIEWED_CONTENT_TABLE).map(get_uids_by_uuid(args.uuid))
            uids_with_info = uids_with_info.join(uids_viewed, by='uid', type='left_only')

        if need_crypta_push:
            # Add the segment-based records, with the same in_sup_base filter.
            crypta_uids = job.table(args.output_table + "_crypta_uids_prepared_for_push") \
                             .filter(sf.defined('in_sup_base'),
                                     sf.equals('in_sup_base', True))

            uids_with_info = job.concat(uids_with_info, crypta_uids)

        # Full records (including the matching source) go to "<output_table>_with_info".
        uids_with_info.sort('uid') \
                  .put(args.output_table + "_with_info")

        # Grouping by install_id leaves one row per install_id in the final table.
        uids_with_info.groupby('install_id') \
                      .aggregate(count=na.count()) \
                      .project('install_id') \
                      .sort('install_id') \
                      .put(args.output_table)

    job.run()

# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
