import argparse
# import time

import datetime

from projects.efficiency_metrics.manager import Manager
from projects.efficiency_metrics.project_config import get_project_cluster


if __name__ == '__main__':

    # NOTE(review): the argparse/Manager-driven flow below is disabled; the
    # script currently runs only the hard-coded ad-hoc jobs that follow it.
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--yt-proxy', type=str, default='Hahn')
    # parser.add_argument('--yt-path-to', required=True, type=str)
    # parser.add_argument('--from-date-hour', required=True, type=str)
    # parser.add_argument('--to-date-hour', required=True, type=str)
    # parser.add_argument('--city', required=True, type=str)
    # parser.add_argument(
    #     '--steps',
    #     required=True,
    #     nargs='+',
    #     choices=[
    #         'prepare_data',
    #         # 'calculate_metrics_by_tz'
    #         # 'calculate_metrics_by_performer',
    #     ],
    # )
    #
    # args = parser.parse_args()

    # assert datetime.datetime.strptime(args.from_date_hour, '%Y-%m-%d')
    # assert datetime.datetime.strptime(args.to_date_hour, '%Y-%m-%d')
    #

    # nile_pipeline_params = {
    #     'period_params' : {
    #         'begin_dttm': args.from_date_hour,
    #         'end_dttm': args.to_date_hour
    #     },
    #     'supply_type_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/alena_lukina_cube_{}_{}',
    #     'sessions_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/sessions_{}_{}',
    #     'sessions_by_performer_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/sessions_by_performer_{}_{}',
    #     'raw_orders_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/raw_orders_{}_{}',
    #     'claims_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/claims_{}_{}',
    #     'sessions_reqs_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/sessions_reqs_{}_{}',
    #     'tmp_burnt_orders_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/tmp_burnt_orders_{}_{}',
    #     'burnt_orders_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/final_burnt_orders_{}_{}',
    #     'final_table_yt_path': '//home/taxi_ml/dev/drivers/delivery/tmp_tmp/final_table_{}_{}'
    # }

    # manager = Manager(
    #     yt_proxy=args.yt_proxy,
    #     nile_pipeline_params=nile_pipeline_params,
    # )
    # # TODO: by unique_driver_id - target supply
    # # TODO: b2b/not b2b flag
    #
    # for step in args.steps:
    #     print(step)
    #     getattr(manager, step)()

    cluster = get_project_cluster()

    import time

    from projects.data_sources.data_context.raw_services_logs import \
        DataContext as RawServicesDataContext

    # First job: write the result of get_taxi_driver_scoring() for a fixed
    # date window to a YT table that the second job below reads back.
    job = cluster.job('Couriers collection' + str(time.time())).env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000},
    )

    # Only consumed by the disabled inner join on 'dbid_uuid' below.
    log_couriers = job.table('//home/taxi_ml/dev/drivers/delivery/priority/couriers_15_02')

    period_start = datetime.datetime.strptime('2021-04-01', '%Y-%m-%d')
    period_end = datetime.datetime.strptime('2021-04-13', '%Y-%m-%d')

    scoring = RawServicesDataContext(
        job, period_start, period_end
    ).get_taxi_driver_scoring()
    # scoring = scoring.join(log_couriers, by='dbid_uuid', type='inner')
    scoring.put('//home/taxi-delivery/analytics/dev/priority/check_changing_13_04')

    job.run()

    from nile.api.v1 import Record

    def _mapper(records):
        """Flatten each input record into date, zone and bonus columns.

        For every input record, yield a ``Record`` with:

        * ``utc_date`` -- the first 10 characters of ``utc_timestamp``
          (presumably the ``YYYY-MM-DD`` part -- confirm the timestamp format);
        * ``zone`` -- copied unchanged;
        * ``express_bonus`` / ``courier_bonus`` -- the ``bonus`` value nested
          under ``bonuses["express-supply-bonus"]`` and
          ``bonuses["courier-supply-bonus"]`` respectively, or ``None`` when
          any level of that nesting is missing or null.
        """
        for record in records:
            # ``or {}`` also covers keys that are present but null: a plain
            # ``.get(key, {})`` returns None in that case and the chained
            # ``.get`` would raise AttributeError.
            bonuses = record.get('bonuses') or {}
            express = bonuses.get('express-supply-bonus') or {}
            courier = bonuses.get('courier-supply-bonus') or {}
            yield Record(
                utc_date=record.utc_timestamp[:10],
                zone=record.zone,
                express_bonus=express.get('bonus'),
                courier_bonus=courier.get('bonus'),
            )


    def _reducer(groups):
        """Count records per ``'{zone}_{int(courier_bonus)}'`` value in each group.

        ``groups`` yields ``(key, records)`` pairs; in this script the groupby
        key is ``utc_date``. For each group, yield one ``Record`` that carries
        the key fields plus one integer column per distinct
        ``'{zone}_{int(courier_bonus)}'`` string seen in the group, holding
        its number of occurrences.

        ``courier_bonus`` is converted with ``int()``. The pipeline in this
        script applies ``qf.defined('courier_bonus')`` before grouping.
        A missing ``zone`` field becomes ``''``.
        """
        from collections import Counter

        for key, records in groups:
            # Build each column name exactly once per record. The previous
            # version re-formatted it (and re-ran int()) for both the
            # membership test and the update.
            counts = Counter(
                '{}_{}'.format(record.get('zone', ''), int(record.courier_bonus))
                for record in records
            )
            yield Record(key, **counts)

    from qb2.api.v1 import filters as qf

    # Second job: read back the table written by the first job, keep rows with
    # a courier bonus, count (zone, courier bonus) combinations per utc_date,
    # and write the date-sorted result.
    job = cluster.job('Couriers collection' + str(time.time())).env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000},
    )

    source_path = '//home/taxi-delivery/analytics/dev/priority/check_changing_13_04'
    target_path = '//home/taxi-delivery/analytics/dev/priority/final_changing_13_04'

    flattened = job.table(source_path).map(_mapper)
    with_courier_bonus = flattened.filter(qf.defined('courier_bonus'))
    daily_counts = with_courier_bonus.groupby('utc_date').reduce(_reducer)
    daily_counts.sort('utc_date').put(target_path)

    job.run()