import argparse
from collections import Counter
import time
import re

import datetime
import geohash as gh
from nile.api.v1 import aggregators as na
from nile.api.v1 import extractors as ne
from nile.api.v1 import filters as nf
from nile.api.v1 import Record
from qb2.api.v1 import filters as qf
import numpy as np

from projects.common.nile.dates import range_selector
from projects.batching.batching_custom_params import (
    VKUSVILL_CORP_CLIENT_ID, VV_ADDRESSES, VV_ADDRESSES_IDS,
    YARCHE_CORP_CLIENT_ID, YARCHE_ADDRESSES
)
from projects.data_sources.data_context.cargo import \
    DataContext as CargoLogsDataContext
from projects.data_sources.data_context.raworders_dmorders_sessions import \
    DataContext as OrdersSessionsLogsDataContext

from projects.data_sources.data_context.raw_services_logs import \
    DataContext as RawServicesLogsDataContext
from projects.efficiency_metrics.project_config import get_project_cluster
from projects.batching.nile_blocks.corps import calc_a_density_reducer
from projects.data_sources.data_context.raw_services_logs import \
    DataContext as RawLogs
from projects.common.time_utils import datetime_2_timestamp, \
    parse_timestring


# def calc_for_courier_express_surge_in_econom(groups):
#     for key, records in groups:
#         # tmp = None
#         econom_surge = None
#         main_tariff = None
#         for record in records:
#             if (
#                     # ('express' in record.user_tariff_classes) or
#                     # ('courier' in record.user_tariff_classes) or
#                     ('econom' in record.user_tariff_classes)
#             ) and (econom_surge is None):
#                 econom_surge = record['surge_value']
#             if (
#                     ('express' in record.user_tariff_classes) or
#                     ('courier' in record.user_tariff_classes)
#             ):
#                 main_tariff = record.user_tariff_classes[0]
#
#
#         d = record.to_dict()
#         del d['user_tariff_classes']
#         del d['tariff']
#         del d['surge_value']
#         d['econom_surge'] = econom_surge
#         d['main_tariff'] = main_tariff
#         if main_tariff is not None:
#             yield Record(
#                 key,
#                 **d
#             )
#

# job.table('//home/taxi-delivery/analytics/dev/taxi_supply_damage/intro').groupby(
#     'order_id'
# ).reduce(
#     calc_for_courier_express_surge_in_econom
# ).put(
#     '//home/taxi-delivery/analytics/dev/taxi_supply_damage/for_courier_express_sruge_in_econom'
# )

def calc_ts_search(due, taxi_utc_order_dttm, timestamp, utc_order_due_dttm):
    """Pick the timestamp at which an order is matched against surge records.

    The first available source wins, in this order:

    1. ``due``: 30 minutes before it, but not earlier than ``timestamp``.
    2. ``utc_order_due_dttm``: 30 minutes before it, but not earlier than
       ``taxi_utc_order_dttm``.
    3. ``timestamp`` as is.
    4. ``taxi_utc_order_dttm`` converted to a timestamp.

    Args:
        due: value accepted by ``int()``, or None.
            NOTE(review): presumably epoch seconds -- confirm.
        taxi_utc_order_dttm: time string parsed as UTC, or None.
        timestamp: number compared against ``int(due) - 1800``, or None.
        utc_order_due_dttm: time string parsed as UTC, or None.

    Returns:
        The chosen timestamp, or None when all four inputs are None.
    """
    lead_sec = 30 * 60

    def _utc_string_to_ts(value):
        return datetime_2_timestamp(parse_timestring(value, 'UTC'))

    if due is not None:
        search_from = int(due) - lead_sec
        # Never feed None into max(): there is nothing to clamp against.
        if timestamp is None:
            return search_from
        return max(search_from, timestamp)

    if utc_order_due_dttm is not None:
        search_from = _utc_string_to_ts(utc_order_due_dttm) - lead_sec
        if taxi_utc_order_dttm is None:
            return search_from
        return max(search_from, _utc_string_to_ts(taxi_utc_order_dttm))

    if timestamp is not None:
        return timestamp

    # `utc_order_due_dttm` is known to be None here, so the only remaining
    # source is the order creation time.
    if taxi_utc_order_dttm is not None:
        return _utc_string_to_ts(taxi_utc_order_dttm)
    return None


def tmp(last_record, future_record, target_record):
    """Merge the neighbour closest in ``timestamp_due`` into the target.

    The result starts as ``target_record.to_dict()``; every field of the
    chosen neighbour is then written on top of it, and ``diff_ts`` is set to
    the distance between the target and that neighbour.

    Args:
        last_record: neighbour preceding the target, or None.
        future_record: neighbour following the target, or None.
        target_record: record to enrich.

    Returns:
        The merged dict, or None when both neighbours are None. On a tie the
        following neighbour is preferred.
    """
    has_last = last_record is not None
    has_future = future_record is not None
    if not has_last and not has_future:
        return None

    target_ts = target_record.timestamp_due
    gap_after = future_record.timestamp_due - target_ts if has_future else None
    gap_before = target_ts - last_record.timestamp_due if has_last else None

    if has_future and (not has_last or gap_after <= gap_before):
        donor, gap = future_record, gap_after
    else:
        donor, gap = last_record, gap_before

    merged = target_record.to_dict()
    merged.update(donor.to_dict())
    merged['diff_ts'] = gap
    return merged


def find_nearest_timestamp_due(groups):
    """Attach to every target record its nearest non-target neighbour.

    Reducer over groups whose records the caller has sorted by
    ``timestamp_due``. Records with ``event == 'target'`` are buffered until
    the next non-target record arrives; each buffered target is then merged
    (via ``tmp``) with whichever of the previous / next non-target record is
    closer in ``timestamp_due``.

    Targets that come after the last non-target record of a group are merged
    with that last record instead of being dropped. Targets in a group that
    contains no non-target record at all have nothing to merge with and are
    skipped.

    Args:
        groups: iterable of ``(key, records)`` pairs.

    Yields:
        One ``Record`` per matched target, built from the dict returned by
        ``tmp``.
    """
    for key, records in groups:
        last_record = None
        pending_targets = []

        for record in records:
            if record.get('event') == 'target':
                pending_targets.append(record)
                continue

            neighbour = Record(record)
            for target in pending_targets:
                yield Record(**tmp(last_record, neighbour, target))
            pending_targets = []
            last_record = neighbour

        # Flush targets that have no following neighbour; without a previous
        # neighbour either, `tmp` would have nothing to merge.
        if last_record is not None:
            for target in pending_targets:
                yield Record(**tmp(last_record, None, target))


def get_cargo_claims(job):
    """Add the cargo-claims branch to ``job`` and return ``job``.

    Claims are projected to the fields used downstream, inner-joined with
    claim points (by ``claim_id``) and points (by ``point_id``), restricted
    to rows whose ``type`` equals ``'source'`` and written to the
    ``all_cargo_claims`` table under ``args.yt_path_dir_to``.

    Reads the module-level ``from_dttm``, ``to_dttm`` and ``args`` that the
    ``__main__`` section assigns before calling this function.
    """
    context = CargoLogsDataContext(job, from_dttm, to_dttm)

    point_stream = context.get_points().project(
        ne.all(['timestamp', 'source'])
    )
    claim_point_stream = context.get_claim_points().project(
        ne.all(['timestamp', 'source'])
    )

    claims = context.get_claims().project(
        'taxi_order_id', 'uuid_id', 'timestamp', 'final_price',
        'corp_client_id',
        'zone_id', 'dispatch_flow', 'final_pricing_calc_id',
        'is_delayed', 'due', 'eta', 'currency',
        log_status='status',
        claim_id='id',
        # UTC calendar day of the claim timestamp.
        utc_date=ne.custom(
            lambda ts: datetime.datetime.utcfromtimestamp(ts).strftime(
                '%Y-%m-%d'),
            'timestamp'
        ),
        # UTC calendar day plus hour of the claim timestamp.
        utc_date_hour=ne.custom(
            lambda ts: datetime.datetime.utcfromtimestamp(ts).strftime(
                '%Y-%m-%d %H'),
            'timestamp'
        ),
        is_success_order=ne.custom(
            lambda status: status in ['delivered_finish', 'returned_finish'],
            'status'
        )
    )

    claims_with_points = claims.join(
        claim_point_stream, by='claim_id', type='inner'
    ).join(
        point_stream, by='point_id', type='inner'
    )

    source_rows = claims_with_points.filter(
        nf.custom(lambda point_type: point_type == 'source', 'type')
    )
    source_rows.put(args.yt_path_dir_to.format('all_cargo_claims'))
    return job


def get_surge_table(job):
    """Add the econom surge branch to ``job`` and return ``job``.

    Offer prices with tariff ``'econom'`` are inner-joined with orders by
    ``offer_id`` and written to the ``intro`` table under
    ``args.yt_path_dir_to``. Missing input tables are ignored.

    Reads the module-level ``from_dttm``, ``to_dttm`` and ``args`` that the
    ``__main__`` section assigns before calling this function.
    """
    econom_tariff = 'econom'

    # Offer-price tables are selected with a per-day name pattern.
    offers_path = '//home/taxi-dwh/ods/mdb/order_offer_price/{}'.format(
        range_selector(from_dttm, to_dttm, '%Y-%m-%d')
    )
    econom_offers = job.table(offers_path, ignore_missing=True).project(
        'utc_created_dttm', 'tariff', 'offer_id', 'surge_value'
    ).filter(
        nf.equals('tariff', econom_tariff)
    )

    # Order tables are selected with a first-of-month name pattern.
    orders_path = '//home/taxi-dwh/ods/mdb/order/{}'.format(
        range_selector(from_dttm, to_dttm, '%Y-%m-01')
    )
    orders = job.table(orders_path, ignore_missing=True).project(
        'order_id', 'offer_id', 'user_tariff_classes',
        s_source_lat='source_lat', s_source_lon='source_lon'
    )

    surge_with_orders = econom_offers.join(orders, by='offer_id', type='inner')
    surge_with_orders.put(args.yt_path_dir_to.format('intro'))

    return job


def _parse_proposition_reducer(groups):
    """Collapse each group into one record carrying its event queue.

    Every input record contributes one
    ``[history_timestamp, history_action, order_id]`` entry, in the order the
    records arrive.

    Yields:
        ``Record(key, queue=[...])`` for each group.
    """
    for key, records in groups:
        history = [
            [row.history_timestamp, row.history_action, row.order_id]
            for row in records
        ]
        yield Record(key, queue=history)

def calc_ts_diff_mapper(records):
    """Turn each queue record into one summary record.

    ``sec_interval`` is the difference between the timestamps of the last
    and the first queue entries; ``final_status`` and ``order_id`` are taken
    from the last entry. Records with an empty queue are skipped.
    """
    for record in records:
        entries = record.queue
        if len(entries) == 0:
            continue

        first_entry = entries[0]
        last_entry = entries[-1]

        order_id = last_entry[2]
        sec_interval = int(last_entry[0]) - int(first_entry[0])

        yield Record(
            contractor_id=record.contractor_id,
            proposition_id=record.proposition_id,
            sec_interval=sec_interval,
            final_status=last_entry[1],
            order_id=order_id,
            utc_date=record.utc_date
        )


if __name__ == '__main__':
    # Script entry point. Builds and runs four jobs one after another; each
    # later job reads tables that an earlier job wrote under
    # --yt-path-dir-to (a template whose '{}' is replaced by the table name).

    parser = argparse.ArgumentParser()
    # NOTE(review): --yt-proxy is parsed but not read anywhere in the
    # visible code.
    parser.add_argument('--yt-proxy', type=str, default='Hahn')
    parser.add_argument('--yt-path-dir-to', type=str,
                        default='//home/taxi-delivery/analytics/dev/taxi_supply_damage_v2/{}')
    # Both dates are required in practice: they are parsed with
    # '%Y-%m-%d' below and have no default.
    parser.add_argument('--start-date', type=str)
    parser.add_argument('--finish-date', type=str)

    # `args`, `from_dttm` and `to_dttm` are module-level names that
    # get_cargo_claims() and get_surge_table() read as globals.
    args = parser.parse_args()

    cluster = get_project_cluster()

    # yt_path_dir_to = args.yt_path_dir_to
    from_dttm = datetime.datetime.strptime(args.start_date, '%Y-%m-%d')
    to_dttm = datetime.datetime.strptime(args.finish_date, '%Y-%m-%d')

    # ---- Job 1: raw extracts (all_log_orders, all_cargo_claims, supply,
    # intro). The current time is appended to the job name.
    job = cluster.job('Couriers collection' + str(time.time()))
    job = job.env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000}
    )

    # Orders with tariff 'express' or 'courier' -> all_log_orders.
    OrdersSessionsLogsDataContext(
        job,
        from_dttm,
        to_dttm,
    ).get_orders().filter(
        # nf.custom(lambda x: x == False, 'corp_order_flg'),
        nf.custom(lambda x: x in set(['express', 'courier']), 'order_tariff')
    ).project(
        'status', 'taxi_status', #'expired_order_flg',
        'order_source_lon', 'order_source_lat',
        'corp_client_id',
        'order_tariff',
        'utc_order_due_dttm',
        'order_source',
        'utc_start_waiting_dttm',
        taxi_local_order_dttm='local_order_dttm',
        taxi_order_id='order_id', taxi_utc_order_dttm='utc_order_dttm',
        # is_b2b_order=ne.const(False),
        zone_id='tariff_zone',
        # First 10 characters of utc_order_dttm
        # (presumably 'YYYY-MM-DD' -- confirm the string format).
        utc_date=ne.custom(
            lambda x:
            x[:10],
            'utc_order_dttm'
        ),
        # First 13 characters of utc_order_dttm
        # (presumably 'YYYY-MM-DD HH' -- confirm the string format).
        utc_date_hour=ne.custom(
            lambda x:
            x[:13],
            'utc_order_dttm'
        ),
        # db_id joined with the second '_'-separated part of driver_id;
        # None when either input is None.
        dbid_uuid=ne.custom(
            lambda x, y: '{}_{}'.format(x, y.split('_')[1]) if (
                    (x is not None) and (y is not None)) else None,
            'db_id', 'driver_id'
        ),
    ).put(
        args.yt_path_dir_to.format('all_log_orders')
        # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/all_log_orders'
    )

    # EXTENDED ORDERS
    # Writes the all_cargo_claims table.
    get_cargo_claims(job)

    # TAGS
    # One row per (dbid_uuid, utc_date) with courier-type flags derived from
    # `tags` -> supply.
    # NOTE(review): `raw_logs_dc` is not referenced again in the visible code.
    raw_logs_dc = RawLogs(
        job,
        from_dttm,
        to_dttm
    ).get_taxi_priority().project(
        ne.all(['tags']),
        is_taxi_courier=ne.custom(lambda x: 'taxi_courier' in x, 'tags'),
        is_auto_courier=ne.custom(lambda x: 'auto_courier' in x, 'tags'),
        is_walking_courier=ne.custom(lambda x: 'walking_courier' in x, 'tags')
    ).unique(
        'dbid_uuid', 'utc_date'
    ).put(
        args.yt_path_dir_to.format('supply')
        # '//home/taxi-delivery/analytics/dev/ld_taxi_damage/supply'
    )

    # Writes the intro table (econom surge values joined with orders).
    get_surge_table(job)

    job.run()

    # ---- Job 2: combine the extracts written by job 1.
    job = cluster.job('Couriers collection' + str(time.time()))
    job = job.env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000}
    )
    # Full join of log orders with cargo claims by taxi_order_id
    # -> all_log_orders_extended.
    job.table(
        args.yt_path_dir_to.format('all_log_orders')
        # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/all_log_orders'
    ).join(
        job.table(
            args.yt_path_dir_to.format('all_cargo_claims')
            # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/all_cargo_claims'
        ),
        type='full', by='taxi_order_id'
    ).put(
        args.yt_path_dir_to.format('all_log_orders_extended')
        # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/all_log_orders_extended'
    ) # :( reorders


    # SESSIONS
    # Sessions of performers that carry at least one courier flag on the
    # same utc_date -> supply_sessions_interesting_supply.
    OrdersSessionsLogsDataContext(
        job,
        from_dttm,
        to_dttm,
    ).get_sessions().join(
        job.table(
            args.yt_path_dir_to.format('supply')
            # '//home/taxi-delivery/analytics/dev/ld_taxi_damage/supply'
        ).filter(
            nf.custom(
                lambda x, y, z: (x is True) or (y is True) or (z is True),
                'is_taxi_courier', 'is_auto_courier', 'is_walking_courier'
            )
        ),
        type='inner', by=['dbid_uuid', 'utc_date']
    ).put(
        args.yt_path_dir_to.format('supply_sessions_interesting_supply')
        # '//home/taxi-delivery/analytics/dev/ld_taxi_damage/supply_sessions_interesting_supply'
    )



    job.run()

    # ---- Job 3: match every order with the nearest-in-time surge record
    # that falls into the same geohash cell.
    job = cluster.job('Couriers collection' + str(time.time()))
    job = job.env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000}
    )

    all_log_orders_extended = job.table(

        args.yt_path_dir_to.format('all_log_orders_extended')
        # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/all_log_orders_extended'
    )

    # Orders are marked event='target'; rows from `intro` get no `event`
    # field. Both streams get geo_hash_order and timestamp_due, are grouped
    # by geohash, sorted by timestamp_due and reduced.
    job.concat(
        all_log_orders_extended.project(
            ne.all(),
            # Geohash (precision 6) of latitude/longitude; falls back to the
            # order source coordinates when `latitude` is None.
            geo_hash_order=ne.custom(
                lambda a, b, c, d: gh.encode(a, b, precision=6)
                if (a is not None)
                else gh.encode(c, d, precision=6),
                'latitude', 'longitude', 'order_source_lat', 'order_source_lon'
            ),
            event=ne.const('target'),
            timestamp_due=ne.custom(
                lambda a, b, c, d: calc_ts_search(a, b, c, d),
                'due', 'taxi_utc_order_dttm', 'timestamp', 'utc_order_due_dttm'
            )
        ),
        job.table(
            args.yt_path_dir_to.format('intro')
            # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/intro'
        ).project(
            ne.all(),
            geo_hash_order=ne.custom(
                lambda x, y: gh.encode(x, y, precision=6),
                's_source_lat', 's_source_lon'
            ),
            timestamp_due=ne.custom(
                lambda x: datetime_2_timestamp(parse_timestring(x, 'UTC')),
                'utc_created_dttm'
            )
        )
    ).groupby(
        'geo_hash_order'
    ).sort(
        'timestamp_due'
    ).reduce(
        find_nearest_timestamp_due
    ).put(
        args.yt_path_dir_to.format('log_orders_2_surge_inner_upd')
        # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/log_orders_2_surge_inner_upd'
    )
    # TODO: check that the time diffs are not too large.
    #  If they are large, recompute with a bigger geohash cell
    #  and hope that, by continuity, the surge is roughly the same.

    job.run()

    # ---- Job 4: per-zone, per-day aggregates.
    job = cluster.job('Couriers collection' + str(time.time()))
    job = job.env(
        bytes_decode_mode='strict',
        yt_spec_defaults={'max_failed_job_count': 1000}
    )


    # Matched orders with surge_value > 1.2 and a defined dbid_uuid, joined
    # with taxi-courier sessions in the listed statuses; duration_sec is
    # summed per (zone_id, utc_date) -> log_orders_2_taxists.
    job.table(
        args.yt_path_dir_to.format('log_orders_2_surge_inner_upd')
        # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/log_orders_2_surge_inner_upd'
    ).filter(
        nf.custom(lambda x: x > 1.2, 'surge_value'),
        qf.defined('dbid_uuid')
    ).join(
        job.table(
            args.yt_path_dir_to.format('supply_sessions_interesting_supply')
            # '//home/taxi-delivery/analytics/dev/ld_taxi_damage/supply_sessions_interesting_supply'
        ).filter(
            nf.custom(lambda x: x is True, 'is_taxi_courier'),
            nf.custom(lambda x: x in ['transporting', 'driving', 'waiting'], 'status')
        ), type='inner', by_right='order_id', by_left='taxi_order_id'
    ).groupby(
        'zone_id', 'utc_date'
    ).aggregate(
        taxi_courier_duration_sec=na.sum('duration_sec')
    ).put(
        args.yt_path_dir_to.format('log_orders_2_taxists')
        # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/log_orders_2_taxists'
    )

    # Sessions in the listed statuses whose available tariff list contains
    # 'econom'; duration_sec is summed per (tariff_geo_zone_code, utc_date)
    # -> zone_sh.
    OrdersSessionsLogsDataContext(
        job,
        from_dttm,
        to_dttm,
    ).get_sessions().filter(
        qf.defined('lcl_available_tariff_class_code_list'),
        nf.custom(lambda x: x in ['transporting', 'driving', 'waiting'],
                  'status'),
        nf.custom(lambda x: 'econom' in x, 'lcl_available_tariff_class_code_list')
    ).groupby(
        'tariff_geo_zone_code', 'utc_date'
    ).aggregate(
        all_duration_sec=na.sum('duration_sec')
    ).put(
        args.yt_path_dir_to.format('zone_sh')
        # '//home/taxi-delivery/analytics/dev/ld_taxi_damage/zone_sh'
    )



    # Taxi-courier sessions in the listed statuses; duration_sec is summed
    # per (tariff_zone, utc_date) -> zone_sh_taxi_courier.
    job.table(
        args.yt_path_dir_to.format('supply_sessions_interesting_supply')
        # '//home/taxi-delivery/analytics/dev/ld_taxi_damage/supply_sessions_interesting_supply'
    ).filter(
        nf.custom(lambda x: x is True, 'is_taxi_courier'),
        nf.custom(lambda x: x in ['transporting', 'driving', 'waiting'], 'status')
    ).groupby(
        'tariff_zone', 'utc_date'
    ).aggregate(
        all_duration_sec_taxi=na.sum('duration_sec')
    ).put(
        args.yt_path_dir_to.format('zone_sh_taxi_courier')
        # '//home/taxi-delivery/analytics/dev/ld_taxi_damage/zone_sh_taxi_courier'
    )

    # Route proposition history: per (proposition_id, contractor_id,
    # utc_date) the span between the first and the last history event is
    # computed, joined with taxi-courier rows of `supply` (join type left at
    # the library default) and summed per (tariff_zone, utc_date)
    # -> proposition_2_taxists.
    # NOTE(review): `tariff_zone` has to come from the `supply` side of the
    # join, since the left side is projected to three fields -- confirm that
    # `supply` carries it.
    job.table(
        '//home/taxi/testing/export/taxi-logistic-dispatcher-production/route_propositions_history/{}'.format(
            range_selector(from_dttm, to_dttm, '%Y-%m-%d')
        ), ignore_missing=True,
    ).project(ne.all(), utc_date=ne.custom(lambda x: datetime.datetime.utcfromtimestamp(int(x)).strftime('%Y-%m-%d'), 'history_timestamp')).groupby(
        'proposition_id', 'contractor_id', 'utc_date'
    ).sort(
        'history_timestamp'
    ).reduce(
        _parse_proposition_reducer
    ).map(
        calc_ts_diff_mapper
    ).filter(
            nf.custom(lambda x: x != '', 'contractor_id'),
    ).project(
            'sec_interval',
            'utc_date',
            dbid_uuid='contractor_id'
    ).join(
        job.table(
            args.yt_path_dir_to.format('supply')
            # '//home/taxi-delivery/analytics/dev/ld_taxi_damage/supply'
        ).filter(
            nf.custom(lambda x: x is True, 'is_taxi_courier')
        ),
        by=['dbid_uuid', 'utc_date']
    ).groupby(
        'tariff_zone', 'utc_date'
    ).aggregate(
        proposition_taxi_courier_duration_sec=na.sum('sec_interval')
    ).put(
        args.yt_path_dir_to.format('proposition_2_taxists')
        # '//home/taxi-delivery/analytics/dev/taxi_supply_damage/proposition_2_taxists'
    )

    job.run()