from collections import Counter

from nile.api.v1 import Record
import numpy as np
import six

from projects.common.time_utils import datetime_2_timestamp, \
    parse_timestring

# TODO: rename functions here

# "<driver status>_<tag>" keys read from ``driver_statuses_2_reqs`` in
# ``_reducer``: every status of F_V combined with every tag of TAGS.
F_V_2_TAGS = [
    "free_auto_courier_tag",
    "free_expeditor_tag",
    "free_multipoints_tag",
    "free_taxi_courier_tag",
    "free_taxi_express_tag",
    "free_thermobag_tag",
    "free_walking_courier_tag",
    "verybusy_auto_courier_tag",
    "verybusy_expeditor_tag",
    "verybusy_multipoints_tag",
    "verybusy_taxi_courier_tag",
    "verybusy_taxi_express_tag",
    "verybusy_thermobag_tag",
    "verybusy_walking_courier_tag"
]
# Driver statuses read from ``driver_statuses`` in ``_reducer``.
F_V = ['free', 'verybusy']
# Tag names read from ``events_2_req`` in ``_reducer``.
# Keep a comma after every entry: adjacent string literals without one are
# silently concatenated by Python into a single (never matching) name.
TAGS = [
    "auto_courier_tag", "expeditor_tag", "multipoints_tag",
    "taxi_courier_tag", "taxi_express_tag",
    "thermobag_tag", "walking_courier_tag",
]
# "<order taxi status>_<tag>" keys read from ``order_taxi_statuses_2_reqs``
# in ``_reducer``; grouped by status in D_T_W order, tags alphabetical.
D_T_W_2_TAGS = [
    # driving
    "driving_auto_courier_tag", "driving_expeditor_tag",
    "driving_multipoints_tag", "driving_taxi_courier_tag",
    "driving_taxi_express_tag", "driving_thermobag_tag",
    "driving_walking_courier_tag",
    # transporting
    "transporting_auto_courier_tag", "transporting_expeditor_tag",
    "transporting_multipoints_tag", "transporting_taxi_courier_tag",
    "transporting_taxi_express_tag", "transporting_thermobag_tag",
    "transporting_walking_courier_tag",
    # waiting
    "waiting_auto_courier_tag", "waiting_expeditor_tag",
    "waiting_multipoints_tag", "waiting_taxi_courier_tag",
    "waiting_taxi_express_tag", "waiting_thermobag_tag",
    "waiting_walking_courier_tag",
]

# Order taxi statuses read from ``order_taxi_statuses`` in ``_reducer``.
D_T_W = ["driving", "transporting", "waiting"]


# Feature names that every ``final_atlas`` dict built by ``calc_features``
# is guaranteed to contain; names not computed for a group are set to 0.
ORDERS_FEATURES = [
    'pending_None', 'n_cands', 'cancelled_driving', 'finished_expired',
    'cancelled_waiting_normed', 'cancelled_driving_normed',
    'assigned_waiting_normed', 'cancelled_waiting',
    'finished_cancelled_normed', 'assigned_transporting_normed',
    'cancelled_None_normed', 'pending_None_normed',
    'assigned_driving_normed', 'c_i', 'finished_complete',
    'finished_cancelled', 'look_for_performer_time',
    'look_for_performer_time_none_length', 'finished_complete_normed',
    'sp', 'assigned_transporting', 'assigned_driving', 'assigned_waiting',
    'cancelled_None', 'finished_expired_normed', 'cancelled_transporting',
    'cancelled_transporting_normed', 'finished_failed',
    'finished_failed_normed',
]




def _reducer(groups):
    for key, records in groups:
        # go through 1-min intervals and average it
        n_events = []
        n_free = []
        n_verybusy = []
        n_orders = []

        f_v_2_tags = []
        f_v = []
        tags = []
        d_t_w_2_tags = []
        d_t_w = []

        for record in records:
            # n_free.append(record.driver_statuses.get('free', 0))
            # n_verybusy.append(record.driver_statuses.get('verybusy', 0))

            n_events.append(record.n_events  * 1.)

            d_t_w_2_tags.append([
                record.order_taxi_statuses_2_reqs.get(el, 0) * 1.
                for el in D_T_W_2_TAGS
            ])
            d_t_w.append(
                [record.order_taxi_statuses.get(el, 0) * 1. for el in D_T_W]
            )
            tags.append(
                [record.events_2_req.get(el, 0) * 1. for el in TAGS]
            )
            f_v.append(
                [record.driver_statuses.get(el, 0) * 1. for el in F_V]
            )
            f_v_2_tags.append(
                [record.driver_statuses_2_reqs.get(el, 0) * 1. for el in F_V_2_TAGS]
            )
            n_orders.append(sum(record.order_taxi_statuses.values()) * 1.)

        final = {
            k: v for k, v in zip(
                F_V_2_TAGS + F_V + TAGS + D_T_W_2_TAGS + D_T_W,

                list(np.mean(f_v_2_tags, 0)) + list(np.mean(f_v, 0)) +
                list(np.mean(tags, 0)) + list(np.mean(d_t_w_2_tags, 0)) +
                list(np.mean(d_t_w, 0))
            )
        }
        final['event'] = np.mean(n_events)  * 1.
        final['n_orders'] = np.mean(n_orders) * 1.
        yield Record(
            key,
            # events=np.mean(n_events), # n of drivers
            # # free=np.mean(n_free),
            # # verybusy=np.mean(n_verybusy),
            # f_v_2_tags=list(np.mean(f_v_2_tags, 0)),
            # f_v=list(np.mean(f_v, 0)),
            # tags=list(np.mean(tags, 0)),
            # d_t_w_2_tags=list(np.mean(d_t_w_2_tags, 0)),
            # d_t_w=list(np.mean(d_t_w, 0)),
            # n_orders=np.mean(n_orders),

            final_orders=final

        )


def calc_orders_features():
    """Return a reducer that turns raw driver events into per-group counters.

    Input records must provide ``dbid_uuid``, ``driver_status``, and the
    integer tag fields listed in ``tag_fields`` (``duplicate_mapper`` produces
    these as 0/1 flags). They may also provide ``order_taxi_status``.

    For every ``(key, records)`` group one record is yielded with:
      * ``n_unique_drivers`` -- number of distinct ``dbid_uuid`` values;
      * ``n_events`` -- number of records;
      * ``driver_statuses`` -- Counter of ``driver_status``;
      * ``order_taxi_statuses`` -- Counter of non-None ``order_taxi_status``;
      * ``order_taxi_statuses_2_reqs`` -- tag sums keyed
        ``'<order_taxi_status>_<tag>'`` (records with an order status only);
      * ``driver_statuses_2_reqs`` -- tag sums keyed ``'<driver_status>_<tag>'``;
      * ``events_2_req`` -- tag sums over all records.
    """
    tag_fields = (
        'thermobag_tag', 'auto_courier_tag',
        'walking_courier_tag', 'taxi_express_tag',
        'taxi_courier_tag', 'expeditor_tag',
        'multipoints_tag',
    )

    def inner(groups):
        for key, records in groups:
            event_total = 0
            drivers_seen = set()
            by_driver_status = Counter()
            by_driver_status_and_tag = Counter()
            by_order_status = Counter()
            by_order_status_and_tag = Counter()
            by_tag = Counter()

            for record in records:
                event_total += 1
                drivers_seen.add(record.dbid_uuid)
                driver_status = record.driver_status
                by_driver_status[driver_status] += 1

                for tag in tag_fields:
                    flag = record.get(tag)
                    by_driver_status_and_tag[
                        '{}_{}'.format(driver_status, tag)
                    ] += flag
                    by_tag[tag] += flag

                # Order-level counters only for events linked to an order.
                if record.get('order_taxi_status') is None:
                    continue
                order_status = record.order_taxi_status
                by_order_status[order_status] += 1
                for tag in tag_fields:
                    by_order_status_and_tag[
                        '{}_{}'.format(order_status, tag)
                    ] += record.get(tag)

            yield Record(
                key,
                n_unique_drivers=len(drivers_seen),
                n_events=event_total,
                driver_statuses=by_driver_status,
                order_taxi_statuses=by_order_status,
                order_taxi_statuses_2_reqs=by_order_status_and_tag,
                driver_statuses_2_reqs=by_driver_status_and_tag,
                events_2_req=by_tag
            )
    return inner


def duplicate_mapper(lag_ts, time_windows_in_sec):
    """Build a mapper that copies driver events into every window they fall in.

    A record belongs to the window of size ``w`` when
    ``lag_ts - w <= record.dttm_utc_1_min < lag_ts``. Per the parameter name,
    ``w`` is in seconds.

    For each matching window, one copy of the record is yielded, extended with:
      * ``window_size``;
      * ``lag_ts``;
      * one integer flag per courier tag, derived from ``record.tags``.

    ``thermobag_tag`` is 1 only when both ``thermobag_confirmed`` and
    ``thermobox_option_on`` are present.
    """
    def _mapper(records):
        for record in records:
            for window_size in time_windows_in_sec:
                window_start = lag_ts - window_size
                if not (window_start <= record.dttm_utc_1_min < lag_ts):
                    continue
                tags = record.tags
                has_thermobag = int('thermobag_confirmed' in tags)
                has_thermobox = int('thermobox_option_on' in tags)
                yield Record(
                    record,
                    window_size=window_size,
                    lag_ts=lag_ts,
                    thermobag_tag=has_thermobag * has_thermobox,
                    auto_courier_tag=int('auto_courier' in tags),
                    walking_courier_tag=int('walking_courier' in tags),
                    taxi_express_tag=int('taxi_express' in tags),
                    taxi_courier_tag=int('taxi_courier' in tags),
                    expeditor_tag=int('expeditor' in tags),
                    multipoints_tag=int('multipoints' in tags)
                )
    return _mapper


def orders_duplicate_mapper(lag_ts, time_windows_in_sec):
    """Build a mapper that copies each order into every window containing it.

    An order belongs to the window of size ``w`` when
    ``lag_ts - w <= record.timestamp < lag_ts``. For each such window, a copy
    of the record extended with ``window_size=w`` is yielded.
    """
    def _mapper(records):
        for record in records:
            for window_size in time_windows_in_sec:
                in_window = lag_ts - window_size <= record.timestamp < lag_ts
                if in_window:
                    yield Record(
                        record,
                        window_size=window_size
                    )
    return _mapper


def calc_features(groups):
    for key, records in groups:
        statuses = Counter()

        cands = []
        cands_i = []
        sp_s = []
        look_for_performer_time_s = []

        order_by_req = []

        for record in records:
            ts_created = datetime_2_timestamp(parse_timestring(record.created, 'UTC'))
            ts_seen = datetime_2_timestamp(
                parse_timestring(record.seen, 'UTC')
            ) if record.get('seen') is not None else None

            # print(look_for_performer_time)

            tz = record.tariff_zone
            geo_hash=record.geo_hash

            n_cands = record.n_candidates
            cands.append(n_cands)
            c_i = record.candidate_index
            cands_i.append(c_i)
            sp = record.get('sp')
            if sp is not None:
                sp_s.append(sp)
            look_for_performer_time = (ts_seen - ts_created) if (
                ts_seen is not None) else (60 * 60.)
            look_for_performer_time_s.append(look_for_performer_time)


            ORDER_REQS = ['thermobag_confirmed', 'car_couriers',
             'too_heavy_no_walking_courier', 'cargo_eds', 'cargo_multipoints',
             'thermobox_option_on']
            order_by_req.append([
                float(el in record.unique_special_reqs) for el in ORDER_REQS
            ])
            status = '{}_{}'.format(
                six.ensure_str(record.status or 'None'),
                six.ensure_str(record.taxi_status or 'None')
            )
            statuses[status] += 1

        final = {}
        final['n_cands'] = np.mean(n_cands)  * 1.
        final['c_i'] = np.mean([
            el if (el is not None) else 10 for el in cands_i
        ])  * 1.
        final['sp'] = np.mean(sp_s)  * 1.
        final['look_for_performer_time'] = np.mean([
            el for el in look_for_performer_time_s if el is not None
        ])  * 1.
        final['look_for_performer_time_none_length'] = len([
            el for el in look_for_performer_time_s if el is None
        ]) * 1.

        # final['order_by_req'] =
        for req, value in zip(ORDER_REQS, list(np.mean(order_by_req, 0))):
            final['{}_orders'.format(req)] = value

        for st, v in statuses.items():
            final[st] = float(v)
            final['{}_normed'.format(st)] = v * 1./ sum(statuses.values())

        for order_feat in ORDERS_FEATURES:
            if order_feat not in final:
                final[order_feat] = 0.

        yield Record(
            key,
            final_atlas=final
        )

def concat_different_windows_reducer(entity_type):
    """Build a reducer that merges per-window feature dicts into one record.

    Args:
        entity_type: mapping with two names:
            * ``'key'`` -- the grouping column copied from the group key;
            * ``'value'`` -- the output column that receives the merged dict.

    For every record of a group, each entry of ``final_atlas`` and then of
    ``final_orders`` is stored under ``'<feature>_<window_size>'``. Later
    entries overwrite earlier ones when names clash.
    """
    def _inner(groups):
        for key, records in groups:
            merged = {}
            for record in records:
                for field in ('final_atlas', 'final_orders'):
                    for name, value in getattr(record, field).items():
                        merged['{}_{}'.format(name, record.window_size)] = value

            columns = {}
            columns[entity_type['value']] = merged
            columns[entity_type['key']] = key.get(entity_type['key'])
            yield Record(**columns)
    return _inner
