import json
import collections
import datetime
import hashlib
import itertools
import typing as tp
import sys
from collections import Counter

import six
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind, mannwhitneyu, shapiro, levene
from nile.api.v1 import Record
from nile.api.v1 import clusters
from nile.api.v1 import filters as nf
from nile.api.v1 import extractors as ne
from nile.api.v1 import aggregators as na
from qb2.api.v1 import filters as qf
from qb2.api.v1 import extractors as qe
from qb2.api.v1 import typing as qt

from projects.common.nile import environment
from projects.common.nile.dates import range_selector
# from taxi_pyml.common import geo
# from taxi_pyml.common import time_utils
# from taxi_pyml.common import loaders

# Analysis window. Earlier values, kept for reference:
#   BEGIN_DTTM = '2021-04-16 15:00:00', END_DTTM = '2021-04-17 12:00:00'
#   GROUP_ID_BUCKETS = {0: 200, 1: 800}
# The (currently disabled) stage-1 job in __main__ reads the
# taxi-exp3-log/1d tables for this range and keeps rows whose 'timestamp'
# lies within [BEGIN_DTTM, END_DTTM].
BEGIN_DTTM = '2021-04-17 15:00:00'
END_DTTM = '2021-04-19 14:00:00'

# Number of hash buckets per experiment group id (the matched `position` of
# EXPERIMENT_MODEL_NAME); _get_bucket spreads a group's orders over them.
# NOTE(review): presumably sized in proportion to each group's traffic
# share -- confirm.
GROUP_ID_BUCKETS = {0: 600, 1: 400}

# exp3 experiment whose matched `position` becomes the order's group_id.
EXPERIMENT_MODEL_NAME = 'umlaas_logistics_performer_availability_model_type'

def experiments_mapper(records):
    """Extract one experiment-assignment record per availability request.

    Only records whose ``consumer`` is
    ``'umlaas-logistics-performer-availability'`` are considered. For each one,
    the JSON ``kwargs`` and ``matched`` fields are parsed and a ``Record`` is
    yielded with:

    * ``group_id``: ``position`` of the first ``matched`` entry whose
      ``experiment3_name`` equals ``EXPERIMENT_MODEL_NAME`` (``None`` when the
      experiment did not match or the entry has no ``position``);
    * ``link``, ``iso_eventtime``: copied from the input record;
    * ``eats_id``: ``kwargs['unique_request_id']``;
    * ``matched_list``: the parsed ``matched`` value re-serialized as JSON;
    * ``kwargs``: the raw, unparsed ``kwargs`` value.

    Requests without a ``unique_request_id`` are skipped.
    """
    for record in records:
        if six.ensure_str(record.consumer) != 'umlaas-logistics-performer-availability':
            continue

        # `or {}` also covers a literal JSON `null` payload.
        kwargs = json.loads(record.get('kwargs') or '{}') or {}
        # .get() rather than indexing: a request without an id must be skipped,
        # not raise KeyError and fail the whole mapper job.
        eats_id = kwargs.get('unique_request_id')
        if eats_id is None:
            continue

        # An empty JSON object iterates like an empty list; the '{}' default is
        # kept so `matched_list` is serialized exactly as before for such rows.
        matched_list = json.loads(record.get('matched') or '{}')
        matched_model = next(
            (item for item in matched_list
             if item.get('experiment3_name') == EXPERIMENT_MODEL_NAME),
            None,
        )
        position_model = None if matched_model is None else matched_model.get('position')

        yield Record(
            group_id=position_model,
            link=record.link,
            eats_id=eats_id,
            iso_eventtime=record.iso_eventtime,
            matched_list=json.dumps(matched_list),
            kwargs=record.get('kwargs'),
        )

def _get_bucket(order_id, group_id):
    n_buckets = GROUP_ID_BUCKETS.get(int(group_id))
    if n_buckets is None:
        return None
    return int(hashlib.md5(order_id).hexdigest(), 16) % n_buckets

# Measure columns summed per group by the aggregation jobs in __main__;
# each is produced (as a count of 1) by _extract_measures.
MEASURES = [
    'success_orders',           # finished claim with a courier assigned
    'not_found_orders',         # courier_not_assigned is truthy
    'courier_assigned_orders',  # courier_not_assigned is falsy
    'cancelled_orders',         # NOTE(review): never yielded by _extract_measures
    'orders',                   # every order
]

def _extract_measures(record):
    """Yield ``(measure_name, 1)`` for each measure a joined claim row counts toward.

    Reads ``record.status`` (bytes or str) and ``record.courier_not_assigned``.
    Always yields ``'orders'``. Also yields exactly one of ``'not_found_orders'``
    or ``'courier_assigned_orders'``, depending on ``courier_not_assigned``.
    Yields ``'success_orders'`` when a courier was assigned and the status is a
    finished delivery/return status.
    """
    status = six.ensure_str(record.status)
    courier_not_assigned = record.courier_not_assigned
    # NOTE(review): the delivered status is spelled 'delivered_finish', so the
    # return status is most likely 'returned_finish'. The original spelling
    # 'returned_finished' is still accepted until confirmed against the data.
    finished_statuses = ('delivered_finish', 'returned_finish', 'returned_finished')

    if status in finished_statuses and not courier_not_assigned:
        yield 'success_orders', 1
    if courier_not_assigned:
        yield 'not_found_orders', 1
    else:
        yield 'courier_assigned_orders', 1
    # NOTE(review): 'cancelled_orders' is listed in MEASURES but never yielded.
    yield 'orders', 1

def measures_mapper(records):
    """Turn each joined order/claim row into one record of per-order measures.

    Every output ``Record`` carries ``eats_id``, ``tariff_zone`` (from
    ``zone_id``), ``group_id``, ``bucket`` and ``corp_client_id`` copied from the
    input row (``None`` when absent). It also carries one count column per
    measure yielded by ``_extract_measures``.
    """
    for record in records:
        measures = dict(_extract_measures(record))
        if not measures:
            continue
        yield Record(
            eats_id=record.get('eats_id'),
            tariff_zone=record.get('zone_id'),
            group_id=record.get('group_id'),
            bucket=record.get('bucket'),
            # The aggregation jobs filter on 'corp_client_id' right after this
            # mapper. Without the column that filter only ever sees None and
            # excludes nothing.
            corp_client_id=record.get('corp_client_id'),
            **measures
        )


from projects.ml_handler_availability.project_config import get_project_cluster

if __name__ == '__main__':
    cluster = get_project_cluster()

    # Pipeline overview. Each stage is a separate YT job; disabled stages are
    # kept as comments so their intermediate tables can be rebuilt:
    #   1. exp3 log                 -> logistics_latest_exp_1       (disabled)
    #   2. exp_1 + YQL result table -> logistics_latest_exp_2       (active)
    #   3. exp_2 + cargo claims     -> logistics_latest_exp_3       (disabled)
    #   4. exp_3 per bucket/group   -> logistics_aggregations       (disabled)
    #   5. exp_3 per group          -> logistics_aggregations_tmp   (active)
    # NOTE(review): stage 2 rewrites logistics_latest_exp_2 while stage 3 is
    # disabled, so stage 5 reads whatever logistics_latest_exp_3 already exists.

    # Stage 1 (disabled): exp3 log rows in [BEGIN_DTTM, END_DTTM] ->
    # experiments_mapper -> one row per non-empty eats_id with a hash bucket.
    # job = cluster.job()
    # job = job.env(
    #     # bytes_decode_mode='strict',
    #     yt_spec_defaults={'max_failed_job_count': 1000}
    # )
    # job.table(
    #     '//home/logfeller/logs/taxi-exp3-log/1d/{}'.format(
    #         range_selector(BEGIN_DTTM, END_DTTM))
    # ).filter(
    #     qf.compare('timestamp', '>=', six.ensure_binary(BEGIN_DTTM)),
    #     qf.compare('timestamp', '<=', six.ensure_binary(END_DTTM)),
    # ).map(
    #     experiments_mapper
    # ).filter(
    #     qf.custom(lambda eats_id: six.ensure_str(eats_id) != "")
    # ).groupby('eats_id').aggregate(
    #     group_id=na.any('group_id')
    # ).project(
    #     ne.all(), qe.custom('bucket', _get_bucket, 'eats_id', 'group_id'),
    # ).filter(
    #     qf.defined('bucket')
    # ).put('//home/taxi_ml/tmp/nkozlovskaya/logistics_latest_exp_1')
    # job.run()

    # Stage 2 (active): join experiment rows with the YQL result table on
    # eats_id == order_nr -> logistics_latest_exp_2.
    join_job = cluster.job()
    experiment_rows = join_job.table(
        '//home/taxi_ml/tmp/nkozlovskaya/logistics_latest_exp_1')
    yql_rows = join_job.table(
        '//tmp/yql/nkozlovskaya/79926c98-bfa2593-c6c55d12-702dcc07')
    joined_rows = experiment_rows.join(
        yql_rows, by_left='eats_id', by_right='order_nr',
    )
    joined_rows.put('//home/taxi_ml/tmp/nkozlovskaya/logistics_latest_exp_2')
    join_job.run()

    # Stage 3 (disabled): attach, per eats_id, the cargo claim with the largest
    # created_ts among claims whose idempotency_token starts with that eats_id
    # (first 13 characters) -> logistics_latest_exp_3.
    # job = cluster.job()
    # job.table('//home/taxi_ml/tmp/nkozlovskaya/logistics_latest_exp_2').join(
    #     job.table(
    #         "//home/taxi/production/replica/postgres/cargo_claims/claims"
    #     ).project(
    #         ne.all(),
    #         eats_id=ne.custom(
    #             lambda idempotency_token: six.ensure_str(idempotency_token)[:13]),
    #     ).groupby('eats_id').top(1, 'created_ts', mode='max'),
    #     by='eats_id'
    # ).put('//home/taxi_ml/tmp/nkozlovskaya/logistics_latest_exp_3')
    # job.run()

    # Stage 4 (disabled): same as stage 5 but additionally grouped by hash
    # bucket -> logistics_aggregations.
    # job = cluster.job()
    # job.table('//home/taxi_ml/tmp/nkozlovskaya/logistics_latest_exp_3').map(
    #     measures_mapper
    # ).project(
    #     qe.all(),
    #     qe.unfold_with_total('tariff_zone', 'tariff_zone'),
    # ).filter(
    #     qf.custom(lambda x: x != b'dc28c565829e48cca458b5feb161d5d6',
    #               'corp_client_id')
    # ).groupby(
    #     'bucket', 'group_id', 'tariff_zone'
    # ).aggregate(
    #     **{name: na.sum(name) for name in MEASURES}
    # ).put('//home/taxi_ml/tmp/nkozlovskaya/logistics_aggregations')
    # job.run()

    # Stage 5 (active): per-order measures, unfolded over tariff_zone with a
    # total (qe.unfold_with_total), one corp client excluded, then every
    # measure in MEASURES summed per (group_id, tariff_zone)
    # -> logistics_aggregations_tmp.
    aggregation_job = cluster.job()
    order_measures = aggregation_job.table(
        '//home/taxi_ml/tmp/nkozlovskaya/logistics_latest_exp_3'
    ).map(measures_mapper)
    zone_measures = order_measures.project(
        qe.all(),
        qe.unfold_with_total('tariff_zone', 'tariff_zone'),
    )
    kept_measures = zone_measures.filter(
        qf.custom(
            lambda corp_client_id: corp_client_id != b'dc28c565829e48cca458b5feb161d5d6',
            'corp_client_id',
        )
    )
    totals = kept_measures.groupby('group_id', 'tariff_zone').aggregate(
        **{measure: na.sum(measure) for measure in MEASURES}
    )
    totals.put('//home/taxi_ml/tmp/nkozlovskaya/logistics_aggregations_tmp')
    aggregation_job.run()