
import datetime
import geohash as gh
import six
from nile.api.v1 import Record
from nile.api.v1 import aggregators as na
from nile.api.v1 import extractors as ne
from nile.api.v1 import filters as nf
from qb2.api.v1 import filters as qf

from projects.common.nile import filters as pf
from projects.common.nile.dates import range_selector
from projects.common.nile.filters import is_success_taxi_order
from projects.common.decorators import cached_method
from projects.common.time_utils import datetime_2_timestamp, parse_timestring
from projects.efficiency_metrics.nile_blocks.raw_orders import \
    extract_order_proc_mapper
from . import paths

# Tariff classes treated as logistics. DataContext.get_orders keeps only
# orders whose `order_tariff` is one of these values.
# TODO: add other logistics tariffs?
LOGISTICS_TARIFFS = ['courier', 'express', 'lavka', 'eda']


def flatten_field_mapper(field_from, field_to):
    """Build a mapper that emits one record per element of an iterable field.

    For every input record the returned mapper iterates over the value stored
    in ``field_from`` and yields a copy of the record in which ``field_from``
    is removed and ``field_to`` holds the current element.

    Records whose ``field_from`` is missing or ``None`` produce no output
    (previously they raised ``TypeError`` when iterating ``None``); an empty
    iterable likewise produces no output.

    Args:
        field_from: name of the field holding the iterable to unroll.
        field_to: name of the field that receives each single element.

    Returns:
        A generator function ``mapper(records)`` yielding ``Record`` objects.
    """
    def mapper(records):
        for record in records:
            values = record.get(field_from)
            if values is None:
                # Nothing to unroll for this record.
                continue

            base_fields = record.to_dict()
            base_fields.pop(field_from, None)
            for value in values:
                # Record(**...) is built from the dict on every iteration, so
                # overwriting the same key between yields is safe.
                base_fields[field_to] = value
                yield Record(**base_fields)

    return mapper


def order_proc_mapper(records):
    """Turn raw order_proc rows into flat per-order records.

    A row is read through ``record.get('doc', {})`` and ``record.updated``.
    It is skipped when:

    * ``doc.order.request.class`` is present and contains none of the
      logistics classes (rows with no class list at all are kept);
    * ``doc.order.nz`` is missing or an empty string;
    * ``doc.order.request.source.geopoint`` is missing or ``None``.

    The geopoint check runs before any mandatory field is read, so a row that
    would be dropped anyway cannot fail on a missing ``status`` or
    ``performer``.

    Yields:
        ``Record`` with ``order_id``, ``lat``/``lon`` (``geopoint[1]`` /
        ``geopoint[0]``), ``geo_hash`` (precision 5), ``unique_special_reqs``,
        ``sp``, ``doc_class_list``, ``candidate_index``, ``seen``, ``created``,
        ``status``, ``taxi_status``, ``n_candidates`` (0 when the document has
        no candidates), ``tariff_zone``, ``candidates`` and ``updated``.

    Raises:
        KeyError: if a kept row lacks ``doc.performer``, ``doc.order.status``
            or ``doc.order.taxi_status``, or a virtual tariff entry lacks
            ``class`` / ``special_requirements``.
    """
    # Same values as the module-level LOGISTICS_TARIFFS list.
    logistics_classes = ('express', 'courier', 'lavka', 'eda')

    for record in records:
        doc = record.get('doc', {})
        doc_order = doc.get('order', {})
        request = doc_order.get('request', {})

        doc_class_list = request.get('class')
        if doc_class_list is not None and not any(
                tariff in doc_class_list for tariff in logistics_classes
        ):
            continue

        tariff_zone = doc_order.get('nz', '')
        if tariff_zone == '':
            continue

        geopoint = request.get('source', {}).get('geopoint')
        if geopoint is None:
            continue

        status = doc_order['status']
        taxi_status = doc_order['taxi_status']
        performer = doc['performer']

        candidates = doc.get('candidates')
        # `candidates` may be absent; count it as zero instead of failing.
        n_candidates = len(candidates) if candidates is not None else 0

        # Requirement ids per virtual tariff class; a class that appears more
        # than once keeps the requirements of its last entry.
        special_order_reqs = {}
        for element in doc_order.get('virtual_tariffs') or ():
            special_order_reqs[element['class']] = [
                item['id'] for item in element['special_requirements']
            ]

        unique_special_reqs = list({
            req_id
            for tariff_reqs in special_order_reqs.values()
            for req_id in tariff_reqs
        })

        yield Record(
            order_id=doc.get('_id'),
            lat=geopoint[1],
            lon=geopoint[0],
            geo_hash=gh.encode(geopoint[1], geopoint[0], precision=5),
            unique_special_reqs=unique_special_reqs,
            sp=request.get('sp'),
            doc_class_list=doc_class_list,
            candidate_index=performer.get('candidate_index'),
            seen=performer.get('seen'),
            created=doc.get('created'),
            status=status,
            taxi_status=taxi_status,
            n_candidates=n_candidates,
            tariff_zone=tariff_zone,
            candidates=candidates,
            updated=record.updated,
        )


class DataContext:
    """Builds nile table streams for one job and one time window.

    Every ``get_*`` method reads a table from ``paths`` through the job passed
    to the constructor. Most of them keep only rows whose timestamp column
    passes ``time_filter`` (i.e. ``pf.dttm_between`` for ``begin_dttm`` and
    ``end_dttm``). The getters are wrapped in ``cached_method``, which
    presumably memoises results in ``self._cache`` — confirm against
    ``projects.common.decorators``.
    """

    def __init__(
            self,
            job,
            begin_dttm: datetime.datetime,
            end_dttm: datetime.datetime,
            sources=None,
    ):
        """Store the job, the time window and the default log sources.

        Args:
            job: object providing ``table(...)`` and ``concat(...)``.
            begin_dttm: start of the time window.
            end_dttm: end of the time window.
            sources: names of log sources used by ``get_sources_logs`` when it
                is called without arguments; defaults to ``['orders']``.
        """
        sources = sources or ['orders']

        self._cache = dict()
        self._job = job
        self.begin_dttm = begin_dttm
        self.end_dttm = end_dttm
        # Builds a filter on `ts_field` bound to this context's time window.
        self.time_filter = lambda ts_field: pf.dttm_between(
            begin_dttm, end_dttm, ts_field,
        )

        self._sources = sources
        # Source name -> getter building the corresponding stream.
        # NOTE(review): get_sources_logs calls each getter without arguments,
        # but get_atlas_drivers has a required `dttm_needed` parameter —
        # confirm that 'atlas_drivers' is actually usable through this map.
        self._source_logs_map = {
            'orders': self.get_orders,
            'sessions': self.get_sessions,
            'raw_orders': self.get_raw_orders,
            'atlas_drivers': self.get_atlas_drivers,
            'light_order_proc': self.get_light_order_proc,
            # 'subsidies': self.get_subsidies,
        }

    def get_job(self):
        """Return the job this context was created with."""
        return self._job

    @cached_method
    def get_sources_logs(self, sources=None):
        """Concatenate the streams of the requested log sources.

        Args:
            sources: iterable of source names (keys of the internal source
                map); falls back to the sources given to the constructor.

        Returns:
            The result of ``job.concat`` over the per-source streams.

        Raises:
            KeyError: if a requested source name is not registered.
        """
        source_names = sources or self._sources

        sources_logs = []
        for source_name in source_names:
            getter = self._source_logs_map.get(source_name)
            if getter is None:
                # Fail with the offending name instead of calling None.
                raise KeyError(
                    'Unknown log source {!r}; known sources: {}'.format(
                        source_name, sorted(self._source_logs_map),
                    )
                )
            sources_logs.append(getter())

        return self._job.concat(*sources_logs)

    @cached_method
    def get_drivers(self):
        """Return unique driver identifier triples.

        Reads ``paths.DM_EXECUTOR_PROFILE_ACT``, keeps rows where
        ``unique_driver_id``, ``executor_profile_id`` and
        ``park_taximeter_id`` are defined, renames ``executor_profile_id`` to
        ``driver_uuid`` and deduplicates on the three resulting columns.
        """
        return (
            self._job.table(paths.DM_EXECUTOR_PROFILE_ACT)
            .filter(
                qf.defined(
                    'unique_driver_id',
                    'executor_profile_id',
                    'park_taximeter_id',
                ),
            )
            .project(
                'unique_driver_id',
                'park_taximeter_id',
                driver_uuid='executor_profile_id',
            )
            .unique('unique_driver_id', 'driver_uuid', 'park_taximeter_id')
        )

    @cached_method
    def get_drivers_hist(self):
        """Return driver history with one row per driver phone.

        Reads ``paths.DRIVERS_HIST``, keeps rows with ``driver_phones``
        defined and unrolls that field into ``driver_phone`` using
        ``flatten_field_mapper``. No time filter is applied.
        """
        result = self._job.table(paths.DRIVERS_HIST).project(
            'utc_hired_dttm', 'park_id', 'effective_from_dttm',
            'driver_phones', 'driver_license_normalized', 'driver_license',
            'driver_id', 'taximeter_park_id', 'driver_uuid'
        ).filter(
            qf.defined('driver_phones')
        ).map(flatten_field_mapper('driver_phones', 'driver_phone'))
        return result

    @cached_method
    def get_orders(self, success_only=False):
        """Return logistics orders inside the time window.

        Reads ``paths.DM_ORDER`` (monthly tables selected with the ``%Y-%m``
        pattern), keeps rows with ``utc_order_dttm`` defined and inside the
        window and with ``order_tariff`` in ``LOGISTICS_TARIFFS``. Adds
        ``log_source='orders'`` and ``timestamp`` computed from
        ``utc_order_dttm`` parsed as UTC.

        Args:
            success_only: when true, additionally require a non-zero
                ``success_order_flg``.
        """
        # TODO: verify that every data context filters by UTC time rather
        # than local time.

        filters = [qf.defined('utc_order_dttm')]
        if success_only:
            filters.append(qf.nonzero('success_order_flg'))
        result = self._job.table(paths.DM_ORDER.format(
            range_selector(self.begin_dttm, self.end_dttm, '%Y-%m')
        )).filter(
            *filters
        ).filter(
            self.time_filter('utc_order_dttm'),
            nf.custom(
                lambda x: six.ensure_str(x) in LOGISTICS_TARIFFS,
                'order_tariff',
            ),
        ).project(
            ne.all(),
            log_source=ne.const('orders'),
            timestamp=ne.custom(
                lambda utc_order_dttm:
                datetime_2_timestamp(
                    parse_timestring(utc_order_dttm, 'UTC')
                )
            )
        )
        return result

    @cached_method
    def get_raw_orders(self):
        """Return order_proc rows processed by ``extract_order_proc_mapper``.

        Reads ``paths.ORDER_PROC`` (tables selected with the ``%Y-%m-01``
        pattern), keeps ``id``, ``order_created`` and ``doc``, filters on
        ``order_created`` and tags the mapped rows with
        ``log_source='raw_orders'``.
        """
        return (
            self._job.table(
                paths.ORDER_PROC.format(
                    range_selector(self.begin_dttm, self.end_dttm, '%Y-%m-01')
                )
            ).project(
                'id', 'order_created', 'doc'
            ).filter(
                self.time_filter('order_created')
            ).map(
                extract_order_proc_mapper,
                intensity='large_data'
            )
            .project(
                ne.all(),
                log_source=ne.const('raw_orders'),
            )
        )

    @cached_method
    def get_order_payment_info(self):
        """Return order payment rows inside the time window.

        Reads ``paths.ORDER_PAYMENT_INFO`` (``%Y-%m-01`` pattern, missing
        tables ignored) and keeps ``order_id``, ``user_cost_vat``,
        ``order_tariff`` and ``utc_order_created_dttm`` for rows whose
        ``utc_order_created_dttm`` is defined and inside the window.

        Unlike ``get_orders``, no logistics-tariff filter is applied and no
        ``timestamp`` / ``log_source`` columns are added.
        """
        return (
            self._job.table(
                paths.ORDER_PAYMENT_INFO.format(
                    range_selector(self.begin_dttm, self.end_dttm, '%Y-%m-01'),
                ),
                ignore_missing=True,
            ).project(
                'order_id', 'user_cost_vat', 'order_tariff', 'utc_order_created_dttm'
            ).filter(
                qf.defined('utc_order_created_dttm'),
                self.time_filter('utc_order_created_dttm'),
            )
        )

    @cached_method
    def get_sessions(self):
        """Return driver session rows inside the time window.

        Reads ``paths.DRIVER_SESSIONS`` (daily tables, missing ones ignored),
        renames ``executor_profile_id`` to ``driver_uuid`` and
        ``executor_status_code`` to ``status``, keeps rows with
        ``driver_uuid``, ``park_taximeter_id`` and ``utc_valid_from_dttm``
        defined and inside the window, and adds:

        * ``utc_session_dt`` and ``utc_date`` — first 10 characters of
          ``utc_valid_from_dttm``;
        * ``dbid_uuid`` — ``'<park_taximeter_id>_<driver_uuid>'``;
        * ``timestamp`` — ``utc_valid_from_dttm`` parsed as UTC;
        * ``log_source='sessions'``.
        """
        return (
            self._job.table(
                paths.DRIVER_SESSIONS.format(
                    range_selector(self.begin_dttm, self.end_dttm, '%Y-%m-%d'),
                ),
                ignore_missing=True,
            )
            .project(
                'order_id',
                'old_unique_driver_id',
                'lcl_available_tariff_class_code_list',
                'lcl_valid_from_dttm',
                'enabled_tariff_class_code_list',
                'distance_km',
                'utc_valid_from_dttm',
                'duration_sec',
                'park_taximeter_id',
                'tariff_geo_zone_code',
                'agglomeration_geo_node_id',
                driver_uuid='executor_profile_id',
                status='executor_status_code',
            )
            .project(
                ne.all(),
                utc_session_dt=ne.custom(
                    lambda x: x[:10], 'utc_valid_from_dttm',
                ),
            )
            .filter(
                qf.defined('driver_uuid'),
                qf.defined('park_taximeter_id'),
                qf.defined('utc_valid_from_dttm'),
                self.time_filter('utc_valid_from_dttm'),
                intensity='data',
            )
            .project(
                ne.all(),
                dbid_uuid=ne.custom(
                    lambda x, y: '{}_{}'.format(x, y),
                    'park_taximeter_id', 'driver_uuid'
                ),
                timestamp=ne.custom(
                    lambda utc_valid_from_dttm: datetime_2_timestamp(
                        parse_timestring(utc_valid_from_dttm, 'UTC'),
                    ),
                ),
                utc_date=ne.custom(
                    lambda x: x[:10], 'utc_valid_from_dttm'
                )
            )
            # NOTE: a left join with get_drivers() on
            # ('driver_uuid', 'park_taximeter_id') used to sit here and is
            # currently disabled.
            .project(
                ne.all(),
                log_source=ne.const('sessions'),
            )
        )

    @cached_method
    def get_dm_subvention_transaction_log(self):
        """Return subvention transactions except discount paybacks.

        Reads ``paths.DM_SUBVENTION_TRANSACTION_LOG`` (``%Y-%m-01`` pattern,
        missing tables ignored), drops rows whose ``detailed_product_name``
        equals ``'subsidy_discount_payback'`` and keeps the columns listed in
        the projection. No time filter is applied here.
        """
        return (
            self._job.table(
                paths.DM_SUBVENTION_TRANSACTION_LOG.format(
                    range_selector(self.begin_dttm, self.end_dttm, '%Y-%m-01'),
                ),
                ignore_missing=True,
            ).filter(
                nf.custom(
                    lambda x: x != 'subsidy_discount_payback',
                    'detailed_product_name'
                )
            ).project(
                "transaction_id",
                "order_id",
                "order_id_list",
                "park_id",
                "product_name",
                "subsidy_commission_value",
                "subsidy_value",
                "subsidy_w_commission_value",
                "tariff_geo_zone_name",
                "taximeter_order_id",
                "taximeter_park_id",
                "unrealized_subsidy_commission_value",
                "utc_transaction_dttm",
                "version_seq",
                "billing_type",
                "etl_updated",
                "event_id",
                "rule_id",
                # TODO: columns not projected yet:
                #   billing_status_name, ccy_code, ccy_rate, currency_code,
                #   detailed_product_name, discount_value, dmd_value,
                #   dms_rate, dms_value,
                #   driver_fix_comission_for_fraud_w_vat_amt,
                #   driver_fix_comission_for_fraud_wo_vat_amt,
                #   driver_fix_comission_w_vat_amt,
                #   driver_fix_comission_wo_vat_amt, driver_uuid,
                #   holded_discount_value, holded_dmd_value,
                #   holded_dms_value, holded_subsidy_commission_value,
                #   holded_subsidy_value, holded_subsidy_w_commission_value,
                #   holded_subvention_value, subvention_value,
                #   transaction_ccy_value, transaction_dms_value,
                #   transaction_type, transaction_value
            )
        )

    @cached_method
    def get_corp_clients(self, get_last_dttm_flag):
        """Return corporate clients from ``paths.DIM_CORP_CLIENTS_HIST``.

        Args:
            get_last_dttm_flag: when true, collapse the history to one row per
                ``client_id`` holding the ``name`` taken with
                ``na.last(..., by='effective_to_dttm')``; otherwise return the
                projected rows (``effective_to_dttm``, ``client_id``,
                ``name``) as they are.
        """
        tmp = self._job.table(
            paths.DIM_CORP_CLIENTS_HIST
        ).project('effective_to_dttm', 'client_id', 'name')

        if get_last_dttm_flag:
            tmp = tmp.groupby('client_id').aggregate(
                name=na.last('name', by='effective_to_dttm')
            )

        return tmp

    @cached_method
    def get_atlas_drivers(self, dttm_needed):
        """Return atlas driver rows for the requested tables.

        Args:
            dttm_needed: iterable of strings joined with commas into a
                ``{a,b,...}`` selector that is substituted into
                ``paths.ATLAS``; missing tables are ignored.

        The projection keeps the listed columns and adds ``dbid_uuid``
        (``'<park_db_id>_<driver_uuid>'``) and ``geo_hash`` (precision 5,
        built from ``lat`` and ``lon``).
        """
        # TODO: fix code here

        return self._job.table(paths.ATLAS.format(
            f'{{{",".join(dttm_needed)}}}'
        ), ignore_missing=True).project(
            # 1614865740 "2021-03-04 16:50:01" 1614865801
            # 'timestamp' and 'tx_status' are two separate columns; without
            # the comma between them the adjacent literals were concatenated
            # into the single name 'timestamptx_status'.
            'car_classes', 'driver_status', 'lat', 'lon', 'order_taxi_status',
            'tags', 'tariff_zone', 'timestamp', 'tx_status', 'dttm_utc_1_min',
            'iso_eventtime',
            dbid_uuid=ne.custom(
                lambda x, y: '{}_{}'.format(x, y), 'park_db_id', 'driver_uuid'
            ),
            geo_hash=ne.custom(lambda x, y: gh.encode(x, y, precision=5), 'lat', 'lon')
        )

    @cached_method
    def get_light_order_proc(self):
        """Return flattened daily order_proc rows inside the time window.

        Reads ``paths.ORDER_PROC_DAILY`` (daily tables, missing ones ignored),
        filters on ``updated`` and applies ``order_proc_mapper``. The resolved
        table path is printed to stdout.
        """
        table_path = paths.ORDER_PROC_DAILY.format(
            range_selector(self.begin_dttm, self.end_dttm, '%Y-%m-%d'),
        )
        # NOTE(review): looks like a leftover debug print — confirm whether it
        # can be dropped or moved to logging.
        print(table_path)
        return (
            self._job.table(
                table_path, ignore_missing=True,
            ).project(
                ne.all(),
            ).filter(
                self.time_filter('updated')
            ).map(
                order_proc_mapper, intensity='large_data'
            )
        )


