import argparse
import sys
import time

import datetime
import geohash as gh
import numpy as np

from nile.api.v1 import extractors as ne
from nile.api.v1 import filters as nf
from nile.api.v1 import Record
from qb2.api.v1 import filters as qf
from qb2.api.v1 import typing as qt

from projects.burnt_orders_research.ld_parsing_nile_block import \
    LogEntrySegmentDetector
from projects.burnt_orders_research.nile_blocks.main import (
    calc_different_stats, dup_mapper, main_reducer, get_all_waybills_by_segment,
    claim_ld_stages_mapper, claims_pattern_ld_stages_reducer
)
from projects.burnt_orders_research.nile_blocks.propositions import(
    proposition_reducer,
)
from projects.common.nile.dates import range_selector
from projects.data_sources.data_context.eda_logs import \
    DataContext as EdaOrdersDataContext
from projects.data_sources.data_context.raw_services_logs import \
    DataContext as RawLogs
from projects.data_sources.data_context.cargo import \
    DataContext as CargoLogsDataContext
from projects.data_sources.data_context.raworders_dmorders_sessions import \
    DataContext as TaxiOrdersDataContext
from projects.efficiency_metrics.project_config import get_project_cluster
from projects.burnt_orders_research.ld_parsing_nile_block import \
    LogEntrySegmentDetector

from collections import deque


def _reducer(groups):
    for key, records in groups:

        # queue_orders = deque()
        # queue_couriers = deque()

        common_queue = []
        for record in records:
            if (record.get('utc_delivered_dttm') is not None):
                common_queue.append(
                    [record.utc_delivered_dttm, 'couriers', record.courier_type,
                     record.courier_id])
            elif record.get('location_lat') is not None:

                common_queue.append([record.utc_claim_created_dttm, 'orders',
                      record.get('taxi_dispatch_cargo_uuid_id'),

                      # "location_lat"
                      # "location_lon"
                      # "order_location_lat"
                      # "order_location_lon"
                      record.get('location_lat'),
                      record.get('location_lon'),
                      record.get('order_location_lat'),
                      record.get('order_location_lon'),

                      ]
            )

            # queue_couriers
        yield Record(
            key,
            common_queue=common_queue
        )

if __name__ == '__main__':

    # Command-line interface. NOTE(review): `args` is parsed but never read
    # again in the visible code, so `--yt-proxy` currently has no effect.
    parser = argparse.ArgumentParser()
    parser.add_argument('--yt-proxy', type=str, default='Hahn')
    # parser.add_argument('--from-date', type=str, required=True)
    # parser.add_argument('--to-date', type=str, required=True)

    args = parser.parse_args()

    cluster = get_project_cluster()

    # from_date = '2021-04-12'
    # from_date = '2021-05-05'
    # from_date = '2021-05-12'
    # to_date = '2021-05-26'
    # from_date = datetime.datetime.strptime(args.from_date, '%Y-%m-%d')
    # to_date = datetime.datetime.strptime(args.to_date, '%Y-%m-%d')

    # Path template for the tables of this research; `{}` is filled with a
    # table name through `.format(...)`.
    yt_path_to = '//home/taxi-delivery/analytics/dev/lavka_ld/london_batching_upd/{}'

    # The job name is suffixed with the current timestamp.
    job = cluster.job('LD metrics' + str(time.time()))
    job = job.env(bytes_decode_mode='strict')

    # Only referenced by the commented-out date filters in this script.
    CUR_DATE = '2021-08-05'

    # Commented-out stage that wrote the 'couriers' table.
    # couriers = job.table('//home/eda-dwh/cdm/order/dm_lavka_order/2021').project(
    #     'utc_claim_created_dttm',
    #     'place_id',
    #     'courier_id',
    #     'courier_type',
    #     'utc_delivered_dttm', # the courier arrived at the Lavka store
    #     # 'utc_courier_assigned_dttm', # performer_draft
    #     # 'utc_adopted_by_courier_dttm', # the courier accepted (confirmed in the Pro app)
    # ).filter(
    #     qf.defined('utc_claim_created_dttm'),
    #     nf.custom(lambda x: x[:10] == CUR_DATE, 'utc_claim_created_dttm'),
    #     qf.defined('courier_id')
    # ).put(
    #     yt_path_to.format('couriers')
    # )
    # .join(
    #         job.table('//home/taxi-dwh/ods/dbdrivers/executor_profile/executor_profile').filter(
    #             qf.defined('bigfood_courier_id')
    #         ).project(
    #             dbid_uuid=n
    #         ),
    #         by_left='courier_id', by_right='bigfood_courier_id', type='inner'
    #     )

    # TODO: the moment a courier comes on shift is not taken into account -
    #  only the unique couriers who completed at least one order are seen,
    #  so there are as many of them as there are unique couriers; some of
    #  them may actually be on a lunch break.


    # Claim -> segment -> waybill mapping. NOTE(review): `segments` is only
    # referenced by the commented-out `orders` join later in this script.
    segments = job.table('//home/taxi/production/replica/postgres/cargo_dispatch/segments').project(
        "claim_id",
        "segment_id",
        "chosen_waybill"
    )

    # orders = job.table('//home/eda-dwh/cdm/order/dm_lavka_order/2021').project(
    #     'taxi_dispatch_cargo_uuid_id',
    #     'utc_claim_created_dttm',
    #     'order_location_lat',
    #     'order_location_lon',
    #     'place_id',
    #     'place_address_short'
    # ).join(
    #     job.table('//home/eda-dwh/ods/wms/store/store').project(
    #         'location_lon', 'location_lat',
    #         place_id=ne.custom(lambda x: int(x), 'store_id')
    #     ), by='place_id', type='inner'
    # ).filter(
    #     qf.defined('utc_claim_created_dttm'),
    #     nf.custom(lambda x: x[:10] == CUR_DATE, 'utc_claim_created_dttm')
    # ).join(
    #     segments, by_left='taxi_dispatch_cargo_uuid_id', by_right='claim_id', type='inner'
    # ).put(
    #     yt_path_to.format('orders')
    # )

    def _detect_batches_reducer(groups):
        for key, records in groups:
            n = 0
            b = []
            for record in records:
                n += 1
                b.append(record.taxi_dispatch_cargo_uuid_id)
            yield Record(
                key,
                n=n,
                b=b
            )
    # job.table(yt_path_to.format('orders')).filter(
    #     qf.defined('chosen_waybill')
    # ).groupby(
    #     'chosen_waybill'
    # ).reduce(
    #     _detect_batches_reducer
    # ).put(yt_path_to.format('orders_ext'))

    # job.table(yt_path_to.format('orders')).join(
    #     job.table(yt_path_to.format('orders_ext')), by='chosen_waybill', type='inner'
    # ).put(
    #     yt_path_to.format('orders_batch')
    # )

    def _red_reducer(groups):
        for key, records in groups:
            cnt = Counter()
            for record in records:
                cnt[('batch' if (record.n > 1) else 'unique')]+=1

            yield Record(
                key,
                **cnt
            )

    # job.table('//home/taxi-delivery/analytics/dev/lavka_ld/london_batching/orders_batch').groupby(
    #     'place_id'
    # ).reduce(
    #     _red_reducer
    # ).put(
    #     '//home/taxi-delivery/analytics/dev/lavka_ld/london_batching/orders_batch_n'
    # )
    #
    # job.concat(*[
    #     job.table(yt_path_to.format('orders')).project(
    #         ne.all(),
    #         event_dttm='utc_claim_created_dttm'
    #     ).filter(
    #         qf.defined('taxi_dispatch_cargo_uuid_id'),
    #         qf.defined("location_lat"),
    #         qf.defined("location_lon"),
    #         qf.defined("order_location_lat"),
    #         qf.defined("order_location_lon")
    #     ),
    #     job.table(yt_path_to.format('couriers')).project(
    #         ne.all(),
    #         event_dttm='utc_delivered_dttm'
    #     )
    # ]).groupby(
    #     'place_id'
    # ).sort('event_dttm').reduce(
    #     _reducer
    # ).put(
    #     yt_path_to.format('places_sorted')
    # )

    # from geopy.distance import geodesic
    # from geopy.distance import geodesic
    from haversine import haversine, Unit


    def diff(order_1, order_2, max_detour_m=800):
        """Tell whether two orders are close enough to be delivered together.

        Both arguments are order events shaped like the ones `_reducer`
        builds: indices 3-4 hold ``location_lat`` / ``location_lon`` and
        indices 5-6 hold ``order_location_lat`` / ``order_location_lon``.
        The start point is read from ``order_1`` only.

        The orders are considered batchable when going to one destination
        and then on to the other is less than ``max_detour_m`` metres longer
        than going straight to that second destination, in at least one of
        the two visiting orders.

        All distances are great-circle distances in metres. The previous
        code divided the kilometre result of ``haversine`` by 1000 while
        comparing against an 800 threshold, so the check always passed.

        :param order_1: first order event (also supplies the start point).
        :param order_2: second order event.
        :param max_detour_m: allowed extra distance in metres.
        :return: ``True`` when the pair may be batched, ``False`` otherwise.
        """
        start = (order_1[3], order_1[4])
        dest_1 = (order_1[5], order_1[6])
        dest_2 = (order_2[5], order_2[6])

        start_to_1 = haversine(start, dest_1, unit=Unit.METERS)
        start_to_2 = haversine(start, dest_2, unit=Unit.METERS)
        # The distance is symmetric, so one value serves both directions.
        between = haversine(dest_1, dest_2, unit=Unit.METERS)

        return bool(
            start_to_1 + between < start_to_2 + max_detour_m
            or start_to_2 + between < start_to_1 + max_detour_m
        )


    def try_assignment(queue_orders, queue_couriers):
        """Assign the first waiting courier to the head of the order queue.

        Both queues are lists that are modified in place and also returned.

        * If either queue is empty nothing happens and the outcome is
          ``'none'``.
        * Otherwise the head order is compared, via `diff`, with every later
          order in queue order. On the first match the head order, the
          matching order and the first courier are removed and the outcome
          is ``'batch'``.
        * If no later order matches, the head order and the first courier
          are removed and the outcome is ``'unique'``.

        :param queue_orders: waiting order events, oldest first.
        :param queue_couriers: waiting courier events, oldest first.
        :return: ``(queue_orders, queue_couriers, outcome)``.
        """
        if not queue_orders or not queue_couriers:
            return (queue_orders, queue_couriers, 'none')

        for i in range(1, len(queue_orders)):
            if diff(queue_orders[i], queue_orders[0]):
                queue_couriers.pop(0)
                # Remove the later index first so that index 0 still points
                # at the head order.
                queue_orders.pop(i)
                queue_orders.pop(0)
                return (queue_orders, queue_couriers, 'batch')

        queue_orders.pop(0)
        queue_couriers.pop(0)
        return (queue_orders, queue_couriers, 'unique')

    from collections import Counter
    def _mapper(records):
        for record in records:

            queue_orders = []
            queue_couriers = []

            common_queue = record.get('common_queue')
            cnt = Counter()
            cnt_tr_type = Counter()

            for el in common_queue:
                #     print datetime.datetime.strptime(el[0], '%Y-%m-%d %H:%M:%S')
                while len(queue_orders) and (
                        datetime.datetime.strptime(el[0], '%Y-%m-%d %H:%M:%S') -
                        datetime.datetime.strptime(queue_orders[0][0],
                                                   '%Y-%m-%d %H:%M:%S')
                ).seconds > 30 * 60:
                    queue_orders.pop(0)

                # if courier
                if el[1] == 'couriers':
                    queue_couriers.append(el)

                    cnt_tr_type[el[2]] += 1
                # if order
                else:
                    # добавляю в заказ
                    queue_orders.append(el)

                # пытаюсь назначить
                queue_orders, queue_couriers, tmp = try_assignment(queue_orders,
                                                              queue_couriers)
                cnt[tmp]+=1


            #     for el in:
            #         cnt_tr_type[el]
            # record.

            cnt['vehicle_share'] = cnt_tr_type['vehicle'] * 1. / sum(cnt_tr_type.values())
            yield Record(
                record,
                # TODO:
                **dict(cnt)

            )

    # Replay the per-place event queues through `_mapper` and store the
    # resulting counters in a separate table.
    places_sorted = job.table(yt_path_to.format('places_sorted'))
    simulated = places_sorted.map(_mapper)
    # NOTE(review): `ne.all(['common_queue'])` presumably keeps every column
    # except `common_queue` - confirm against the nile extractors docs.
    trimmed = simulated.project(ne.all(['common_queue']))
    trimmed.put(yt_path_to.format('places_sorted_mapped'))

    job.run()



