import argparse
import sys
import time

import datetime
import geohash as gh
import numpy as np

from nile.api.v1 import extractors as ne
from nile.api.v1 import filters as nf
from nile.api.v1 import Record
from qb2.api.v1 import filters as qf
from qb2.api.v1 import typing as qt

from projects.burnt_orders_research.ld_parsing_nile_block import \
    LogEntrySegmentDetector
from projects.burnt_orders_research.nile_blocks.main import (
    calc_different_stats, dup_mapper, main_reducer, get_all_waybills_by_segment,
    claim_ld_stages_mapper, claims_pattern_ld_stages_reducer
)
from projects.burnt_orders_research.nile_blocks.propositions import(
    proposition_reducer,
)
from projects.common.nile.dates import range_selector
from projects.data_sources.data_context.eda_logs import \
    DataContext as EdaOrdersDataContext
from projects.data_sources.data_context.raw_services_logs import \
    DataContext as RawLogs
from projects.data_sources.data_context.cargo import \
    DataContext as CargoLogsDataContext
from projects.data_sources.data_context.raworders_dmorders_sessions import \
    DataContext as TaxiOrdersDataContext
from projects.efficiency_metrics.project_config import get_project_cluster
from projects.burnt_orders_research.ld_parsing_nile_block import \
    LogEntrySegmentDetector


# CUSTOM_PARSED_LD_LOGS = (
#     '//home/taxi-delivery/analytics/production/parsed_logs/segments_2_ld_parsed/{}'
# )


def get_propositions(job, from_date, to_date):
    """Reduce LD route-proposition history to one row per proposition.

    Proposition events are read from the raw-services logs for the
    ``from_date``/``to_date`` window, grouped by ``proposition_id``,
    ordered by ``history_timestamp`` and folded with ``proposition_reducer``.
    The result is written to the ``propositions_`` table under the
    module-level ``yt_path_to`` template (defined in the ``__main__`` block).

    Returns:
        The same ``job`` object, so stage calls can be chained.
    """
    history = RawLogs(job, from_date, to_date).get_ld_route_propositions()
    ordered_history = history.groupby('proposition_id').sort('history_timestamp')
    reduced = ordered_history.reduce(proposition_reducer)
    reduced.put(yt_path_to.format('propositions_'))
    return job


def get_ld_basic_stats_and_timeline(job, from_date, to_date):
    """Compute per-segment LD dispatch statistics.

    LD dispatch log entries (``get_ld_dispatch`` is given ``to_date`` as a
    ``YYYY-MM-DD`` string) that carry a ``segment_id`` are passed through
    ``LogEntrySegmentDetector``. The detector output is filtered on
    ``segment_id`` once more, grouped by segment, ordered by
    ``timestamp_ld`` and reduced with ``calc_different_stats``. The result is
    written to the ``calc_different_stats`` table under the module-level
    ``yt_path_to`` template.

    Returns:
        The same ``job`` object, so stage calls can be chained.
    """
    logs = RawLogs(job, from_date, to_date)
    dispatch_entries = logs.get_ld_dispatch(to_date.strftime('%Y-%m-%d'))

    entries_with_segment = dispatch_entries.filter(qf.defined('segment_id'))
    # Re-check after the detector: presumably it can emit rows without a
    # segment_id (TODO confirm against LogEntrySegmentDetector).
    detected = entries_with_segment.map(
        LogEntrySegmentDetector()
    ).filter(qf.defined('segment_id'))

    per_segment = detected.groupby('segment_id').sort('timestamp_ld')
    per_segment.reduce(calc_different_stats).put(
        yt_path_to.format('calc_different_stats')
    )
    return job


def get_waybills(job, from_date, to_date):
    """Gather every waybill attached to each segment.

    Waybill-segment links are inner-joined with waybills on
    ``waybill_external_ref == external_ref``. The joined rows are grouped by
    ``segment_id``, ordered by ``waybill_building_version`` then
    ``created_ts``, and reduced with ``get_all_waybills_by_segment``. The
    result is written to the ``all_waybills_by_segment`` table under the
    module-level ``yt_path_to`` template.

    Returns:
        The same ``job`` object, so stage calls can be chained.
    """
    cargo = CargoLogsDataContext(job, from_date, to_date)
    segment_links = cargo.get_waybills_segments()
    waybills = cargo.get_waybills()

    linked = segment_links.join(
        waybills,
        by_right='external_ref',
        by_left='waybill_external_ref',
        type='inner',
    )
    ordered = linked.groupby('segment_id').sort(
        'waybill_building_version', 'created_ts'
    )
    ordered.reduce(get_all_waybills_by_segment).put(
        yt_path_to.format('all_waybills_by_segment')
    )
    return job


def get_table_with_claims(job, from_date, to_date):
    """Return claims left-joined with their segments, limited to a UTC day window.

    Claims (``taxi_order_id``, ``uuid_id``, ``is_delayed``, ``due``,
    ``timestamp``) from the cargo data context are left-joined with segments
    on ``uuid_id == claim_id``. Two columns are derived from the claim
    ``timestamp``, read as seconds since the epoch in UTC:

    * ``utc_date`` in ``YYYY-MM-DD`` format;
    * ``utc_date_hour`` in ``YYYY-MM-DD HH`` format.

    Only rows with ``from_date <= utc_date < to_date`` are kept. The bounds
    are compared as ``YYYY-MM-DD`` strings, so the window is half-open.

    Args:
        job: nile job the stream is attached to.
        from_date: inclusive start of the window (any object with ``strftime``).
        to_date: exclusive end of the window (any object with ``strftime``).

    Returns:
        A nile stream (not written anywhere) with columns ``taxi_order_id``,
        ``timestamp``, ``claim_id``, ``due``, ``is_delayed``, ``s_timestamp``,
        ``segment_id``, ``employer``, ``utc_date`` and ``utc_date_hour``.
    """
    # Format the window bounds once instead of inside the per-record filter
    # lambdas, where strftime would be re-evaluated for every row.
    from_day = from_date.strftime('%Y-%m-%d')
    to_day = to_date.strftime('%Y-%m-%d')

    cargo_d_c = CargoLogsDataContext(job, from_date, to_date)

    claims = cargo_d_c.get_claims().project(
        'taxi_order_id', 'uuid_id', 'is_delayed', 'due', 'timestamp'
    )
    segments = cargo_d_c.get_segments()

    return claims.join(
        segments,
        by_left='uuid_id', by_right='claim_id', type='left',
    ).project(
        ne.all(),
        utc_date=ne.custom(
            lambda x:
            datetime.datetime.utcfromtimestamp(x).strftime('%Y-%m-%d'),
            'timestamp'
        ),
        utc_date_hour=ne.custom(
            lambda x:
            datetime.datetime.utcfromtimestamp(x).strftime('%Y-%m-%d %H'),
            'timestamp'
        )
    ).filter(
        nf.custom(lambda x: x >= from_day, 'utc_date'),
        nf.custom(lambda x: x < to_day, 'utc_date')
    ).project(
        'taxi_order_id',
        'timestamp',
        "claim_id",
        "due",
        "is_delayed",
        "s_timestamp",
        "segment_id",
        'employer',
        'utc_date',
        'utc_date_hour'
    )


def find_id(queue):
    """Return the first non-empty id stored at index 3 of a queue entry.

    Args:
        queue: iterable of indexable entries (e.g. the ``queue`` column of a
            proposition row), or ``None`` when the column is missing.

    Returns:
        The first ``entry[3]`` that is neither ``''`` nor ``None``. Returns
        ``None`` when there is no such entry, including for an empty or
        ``None`` queue.
    """
    if queue is None:
        # A missing column value would otherwise raise TypeError in the mapper.
        return None
    for entry in queue:
        candidate = entry[3]
        # Skip both '' and None placeholders so they cannot mask a valid id
        # that appears later in the queue.
        if candidate is not None and candidate != '':
            return candidate
    return None


def tmp():
    """No-op placeholder: takes no arguments, does nothing, returns None."""
    return None



if __name__ == '__main__':
    # Script entry point: builds the LD dispatch-metrics tables in three
    # sequential YT jobs. Each job reads tables written by the previous one.
    # ``yt_path_to`` is assigned here at module level because the pipeline
    # functions above format their table paths through this global.

    parser = argparse.ArgumentParser()
    # NOTE(review): --yt-proxy is accepted but never read; the cluster comes
    # from get_project_cluster(). Kept so existing invocations still parse.
    parser.add_argument('--yt-proxy', type=str, default='Hahn')
    parser.add_argument('--from-date', type=str, required=True)
    parser.add_argument('--to-date', type=str, required=True)

    args = parser.parse_args()

    cluster = get_project_cluster()

    from_date = datetime.datetime.strptime(args.from_date, '%Y-%m-%d')
    to_date = datetime.datetime.strptime(args.to_date, '%Y-%m-%d')

    yt_path_to = '//home/taxi-delivery/analytics/production/ld/dispatch_metrics/{}'

    def _new_job():
        """Create a YT job with the name prefix and env shared by every stage."""
        return cluster.job(
            'LD metrics' + str(time.time())
        ).env(bytes_decode_mode='strict')

    # Stage 1: per-segment LD dispatch stats -> 'calc_different_stats'.
    # get_propositions() and get_waybills() are not run by this script.
    job = get_ld_basic_stats_and_timeline(_new_job(), from_date, to_date)
    job.run()

    # Stage 2: inner-join the per-segment stats with claims from the UTC
    # date window -> 'calc_different_stats_2_claims_'.
    job = _new_job()
    job.table(yt_path_to.format('calc_different_stats')).join(
        get_table_with_claims(job, from_date, to_date),
        type='inner', by='segment_id'
    ).put(
        yt_path_to.format('calc_different_stats_2_claims_')
    )
    job.run()

    # Stage 3: per-claim LD stages joined with proposition timings.
    job = _new_job()

    ld_stats = job.table(
        yt_path_to.format('calc_different_stats_2_claims_')
    ).map(
        claim_ld_stages_mapper, memory_limit=100000,
    ).filter(
        qf.defined('taxi_order_id')
    )

    # NOTE(review): 'propositions_' is only written by get_propositions(),
    # which this script does not call -- it must exist from a previous run.
    propositions = job.table(
        yt_path_to.format('propositions_')
    ).project(
        "approved_or_rejected_ts",
        "fst_approved_or_reject_ts",
        "fst_ts_order_id_is_not_none",
        "proposition_id",
        "proposition_ts_created",
        "sec_from_created_proposition_to_order_id",
        "sec_from_order_id_to_1st_reject_or_approve",
        taxi_order_id=ne.custom(find_id, 'queue')
    ).filter(
        qf.defined('taxi_order_id')
    )

    # Output schema; its key order also fixes the projected column order.
    schema = {
        'segment_id': qt.String,
        'dispatch_zone_id': qt.String,
        'corp_client_id': qt.String,
        'employer_code': qt.String,
        'group': qt.String,
        "sec_till_1st_assignment": qt.Float,
        "sec_till_1st_candidates": qt.Float,
        "sec_till_1st_edges": qt.Float,
        "sec_till_1st_propose_request": qt.Float,
        "sec_till_1st_segment_info": qt.Float,
        'sec_from_created_proposition_to_order_id': qt.Float,
        'sec_from_order_id_to_1st_reject_or_approve': qt.Float,
        'utc_date': qt.String,
        'utc_date_hour': qt.String,
        'is_delayed': qt.Bool,
        'n_steps_candidates': qt.Int64,
        'n_steps_edges': qt.Int64,
        'n_steps_assignment': qt.Int64,
        'n_candidates_per_edge': qt.Float,
        'share_regular_edges': qt.Float,
        'most_popular_reject_reason': qt.String
    }

    # NOTE(review): 'mertics' looks like a typo, but it is the existing output
    # table path -- rename only together with its downstream readers.
    propositions.join(
        ld_stats,
        type='inner', by='taxi_order_id'
    ).project(
        *schema.keys()
    ).put(
        yt_path_to.format('dispatch_mertics_schema'),
        schema=schema
    )

    job.run()
