#-*- coding: UTF-8 -*-
from common import *


class get_uids_with_interest:
    """Mapper that keeps records matching at least one requested id.

    A record is yielded (unchanged, at most once) when any value in its
    ``audience_segments`` field is in ``audience_segments_ids`` or any value
    in its ``longterm_interests`` field is in ``longterm_interests_ids``.
    A field that is missing or ``None`` is treated as having no values.
    """

    # Record fields to inspect, paired with the attribute holding wanted ids.
    _FIELD_TO_ATTR = (
        ('audience_segments', 'audience_segments_ids'),
        ('longterm_interests', 'longterm_interests_ids'),
    )

    def __init__(self,
                 audience_segments_ids,
                 longterm_interests_ids):
        """Remember the ids to look for.

        Both arguments are iterables of hashable ids; they are stored as
        given and converted to lookup sets on each call.
        """
        self.audience_segments_ids = audience_segments_ids
        self.longterm_interests_ids = longterm_interests_ids

    def __call__(self, recs):
        """Yield every record from ``recs`` that has a wanted id.

        ``recs`` is an iterable of mapping-like records supporting
        ``get(key, default)``; values inside the inspected fields must be
        hashable.
        """
        # Build the lookup sets once per call rather than scanning the id
        # lists for every value of every record.
        wanted = [(field, frozenset(getattr(self, attr)))
                  for field, attr in self._FIELD_TO_ATTR]
        for rec in recs:
            for field, ids in wanted:
                # .get() instead of rec[field]: a record lacking the field
                # is simply a non-match, not a KeyError.
                values = rec.get(field, None)
                if values is None:
                    continue
                if any(value in ids for value in values):
                    yield rec
                    break  # one match is enough; never yield a record twice

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--audience_segments_ids', nargs='+', default=[])
    parser.add_argument('--longterm_interests_ids', nargs='+', default=[])
    parser.add_argument('--regions', nargs='+', default=[])
    parser.add_argument('--browsers', nargs='+', default=[])
    parser.add_argument('--output_table', type=str, required=True)
    parser.add_argument('--input_crypta_table', type=str, required=True)
    parser.add_argument('--input_yandexuid_info_table', type=str, required=True)
    parser.add_argument('--sample_size', type=float, required=True)
    args = parser.parse_args()

    regions = [int(elem) for elem in args.regions]
    audience_segments_ids = [int(elem) for elem in args.audience_segments_ids]
    longterm_interests_ids = [int(elem) for elem in args.longterm_interests_ids]
    print "regions :", regions
    print "audience_segments :", audience_segments_ids
    print "longterm_interests :", longterm_interests_ids
    cluster = clusters.yt.Hahn().env(parallel_operations_limit=10,
                                     yt_spec_defaults=dict(
                                         pool_trees=["physical"],
                                         tentative_pool_trees=["cloud"]
                                     ),
                                     templates=dict(
                                         tmp_root='//tmp',
                                         title='GetUidsForPushes'
                                     ))

    job = cluster.job()
    uids_with_interest = job.table(args.input_crypta_table) \
                            .map(get_uids_with_interest(audience_segments_ids,
                                                        longterm_interests_ids))

    uids_with_region = job.table(args.input_yandexuid_info_table) \
                          .filter(sf.one_of("main_region", regions),
                                  sf.one_of("browser", args.browsers))

    uids = uids_with_interest.project(ne.all(), uid=ne.custom(lambda x: str(x), "yandexuid")) \
                             .join(uids_with_region, type='inner', by_left='uid', by_right='id')
    uids.put(args.output_table + '_with_info')
    job.run()

    prepare_uids_to_push(cluster, args.output_table + '_with_info', args.output_table, [])
    if cluster.driver.client.get_attribute(args.output_table, 'row_count', 0) == 0:
        raise Exception("Empty Table")

    job = cluster.job()
    job.table(args.output_table) \
       .project("install_id") \
       .random(fraction=min(1, args.sample_size / cluster.driver.client.get_attribute(args.output_table, 'row_count', 0))) \
       .sort("install_id") \
       .put(args.output_table)
    job.run()


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
