import argparse
from collections import Counter
import time
import re

import datetime
from nile.api.v1 import aggregators as na
from nile.api.v1 import extractors as ne
from nile.api.v1 import filters as nf
from nile.api.v1 import Record
from qb2.api.v1 import filters as qf
import numpy as np


from projects.batching.batching_custom_params import (
    VKUSVILL_CORP_CLIENT_ID, VV_ADDRESSES, VV_ADDRESSES_IDS,
    YARCHE_CORP_CLIENT_ID, YARCHE_ADDRESSES
)
from projects.data_sources.data_context.cargo import \
    DataContext as CargoLogsDataContext
from projects.data_sources.data_context.raworders_dmorders_sessions import \
    DataContext as OrdersSessionsLogsDataContext

from projects.data_sources.data_context.raw_services_logs import \
    DataContext as RawServicesLogsDataContext
from projects.efficiency_metrics.project_config import get_project_cluster
from projects.batching.nile_blocks.corps import calc_a_density_reducer



if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--yt-proxy', type=str, default='Hahn')
    # parser.add_argument('--yt-path-dir-to', type=str,
    #                     default='//home/taxi-delivery/analytics/dev/{}')
    # parser.add_argument('--start-date', type=str)
    # parser.add_argument('--finish-date', type=str)

    args = parser.parse_args()

    cluster = get_project_cluster()

    # yt_path_dir_to = args.yt_path_dir_to
    # start_date = args.start_date
    # finish_date = args.finish_date

    def calc_tr(times, times_point_c, type):
        tr = None
        if type == 'a1-a2-b1-b2':
            tr = (
                    times_point_c.get('a1_c') + times_point_c.get('c_a2') +
                    times.get('a2_b1') + times.get('b1_b2')
            )
        if type == 'a1-a2-b2-b1':
            tr = (
                    times_point_c.get('a1_c') + times_point_c.get('c_a2') +
                    times.get('a2_b2') + times.get('b2_b1')
            )
        return tr


    def get_last_record_reducer(groups):
        for key, records in groups:
            new_rec = None
            for record in records:
                new_rec = record
            yield Record(new_rec)

    # job = cluster.job('Couriers collection' + str(time.time()))
    # job = job.env(
    #     bytes_decode_mode='strict',
    #     yt_spec_defaults={'max_failed_job_count': 1000}
    # )
    #
    # job.table(
    #     '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/orders_batches_filtered_severe'
    # ).project(
    #     ne.all(),
    #     transporting_wo_batch=ne.custom(
    #         lambda x, y: (
    #             x.get('plan_travel_time_min') + y.get('plan_travel_time_min')
    #         ) * 60,
    #         'order_id_1_info', 'order_id_2_info'
    #     ),
    #     driving_wo_batch=ne.custom(
    #         lambda x, y: (
    #                 x.get('eta_duration_sec') + y.get('eta_duration_sec')
    #         ),
    #         'order_id_1_info', 'order_id_2_info'
    #     ),
    #     lengths_in_km_1=ne.custom(
    #         lambda x: x.get('plan_travel_time_min') * 1. / 60 * 24,
    #         'order_id_1_info'
    #     ),
    #     lengths_in_km_2=ne.custom(
    #         lambda x: x.get('plan_travel_time_min') * 1. / 60 * 24,
    #         'order_id_2_info'
    #     ),
    #     driving_w_batch=ne.custom(lambda x: x.get('eta_duration_sec'), 'order_id_1_info'),
    #     driving_economy=ne.custom(lambda x: x.get('eta_duration_sec'), 'order_id_2_info'),
    #     transporting_w_batch=ne.custom(
    #         lambda x, y, z: calc_tr(x, y, z), 'times', 'times_point_c', 'type'
    #     ),
    #
    #     transporting_economy=ne.custom(
    #         lambda a, b, c, d, e: ((
    #             a.get('plan_travel_time_min') + b.get('plan_travel_time_min')
    #         ) * 60) * 1. / (
    #             calc_tr(c, d, e)
    #         ) - 1,
    #         'order_id_1_info', 'order_id_2_info', 'times', 'times_point_c', 'type'
    #     ),
    #     transporting_economy_abs=ne.custom(
    #         lambda a, b, c, d, e: (((
    #                                        a.get(
    #                                            'plan_travel_time_min') + b.get(
    #                                    'plan_travel_time_min')
    #                                ) * 60) - (
    #                                   calc_tr(c, d, e)
    #                               )) * (-1.),
    #         'order_id_1_info', 'order_id_2_info', 'times', 'times_point_c',
    #         'type'
    #     )
    # ).project(
    #     ne.all(),
    #     # transporting_economy=(
    #     #     lambda x, y: (x * 1. / y - 1), 'transporting_wo_batch', 'transporting_w_batch'
    #     # )
    # ).put(
    #     '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff'
    # )
    #
    # job.run()
    #
    # job = cluster.job('Couriers collection' + str(time.time()))
    # job = job.env(
    #     bytes_decode_mode='strict',
    #     yt_spec_defaults={'max_failed_job_count': 1000}
    # )
    #
    # job.table('//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff').filter(
    #     nf.custom(lambda x: x > 0, 'transporting_economy')
    # ).put(
    #     '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_filtered_positive_economy'
    # )
    #
    # job.run()

    job = cluster.job('Couriers collection' + str(time.time()))
    job = job.env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000}
    )

    job.table(
        '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_filtered_positive_economy'
    ).groupby(
        'order_id_1', 'order_id_2'
    ).aggregate(
        driving_economy=na.any('driving_economy'),
        lengths_in_km_1=na.any('lengths_in_km_1'),
        lengths_in_km_2=na.any('lengths_in_km_2'),
        transporting_economy=na.max('transporting_economy'),
        transporting_economy_abs=na.last('transporting_economy_abs', by='transporting_economy')
    ).put(
        '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_opt_marshrut'
    )

    job.run()

    job = cluster.job('Couriers collection' + str(time.time()))
    job = job.env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000}
    )

    job.table(
        '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_opt_marshrut'
    ).groupby(
        'order_id_1'
    ).aggregate(
        driving_economy=na.mean('driving_economy'),
        lengths_in_km_1=na.mean('lengths_in_km_1'),
        lengths_in_km_2=na.mean('lengths_in_km_2'),
        transporting_economy=na.mean('transporting_economy'),
        transporting_economy_abs=na.mean('transporting_economy_abs')
    ).put(
        '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_opt_marshrut_by_1'
    )

    job.run()

    job = cluster.job('Couriers collection' + str(time.time()))
    job = job.env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000}
    )

    job.table(
        '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_opt_marshrut'
    ).groupby(
        'order_id_1'
    ).reduce(get_last_record_reducer).put(
        '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_opt_marshrut_by_1_tmp'
    )

    job.run()

    job = cluster.job('Couriers collection' + str(time.time()))
    job = job.env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000}
    )

    job.table(
        '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_opt_marshrut_by_1_tmp'
    ).groupby(
        'order_id_2'
    ).reduce(get_last_record_reducer).put(
        '//home/taxi_ml/dev/drivers/delivery/2020-11-30_moscow/task_2/for_fix_tariff_opt_marshrut_by_1_tmp_2'
    )


    job.run()