import datetime
import hashlib
import six
from typing import AnyStr, Any, Union

from nile.api.v1 import filters as nf
from qb2.api.v1 import filters as qf

from projects.common import time_utils


def between(field: Any, begin: Any, end: Any):
    """
    Get filter for the half-open range begin <= field < end
    :param field: left-hand side passed to qf.compare
    :param begin: lower bound, inclusive
    :param end: upper bound, exclusive
    :return: qb2 filter joining both comparisons with qf.and_
    """
    lower_bound_ok = qf.compare(field, '>=', begin)
    upper_bound_ok = qf.compare(field, '<', end)
    return qf.and_(lower_bound_ok, upper_bound_ok)


def ts_between(
        begin: Union[datetime.datetime, int, float],
        end: Union[datetime.datetime, int, float],
        ts_field: AnyStr = 'timestamp',
):
    """
    Get filter for begin <= ts_field < end
    :param begin: lower bound (inclusive); a datetime is converted with
        time_utils.datetime_2_timestamp, other values are used as is
    :param end: upper bound (exclusive), handled the same way as begin
    :param ts_field: name of the field to compare
    :return: filter built by between()
    """

    def as_timestamp(bound):
        # Only datetime objects need conversion; numbers pass through.
        if isinstance(bound, datetime.datetime):
            return time_utils.datetime_2_timestamp(bound)
        return bound

    lower = as_timestamp(begin)
    upper = as_timestamp(end)
    return between(ts_field, lower, upper)


def dttm_between(
        begin: datetime.datetime,
        end: datetime.datetime,
        dt_field: str,
        format: str = '%Y-%m-%d %H:%M:%S',
):
    """
    Get filter that passes records with begin <= parsed dt_field < end
    :param begin: datetime, lower bound (inclusive)
    :param end: datetime, upper bound (exclusive)
    :param dt_field: str, name of the field holding the datetime text
    :param format: str, format handed to time_utils.parse_datetime
    :return: qb2 custom filter over dt_field
    """

    def is_in_range(raw_value):
        # Coerce to native str with six.ensure_str before parsing.
        parsed = time_utils.parse_datetime(
            value=six.ensure_str(raw_value),
            format=format,
        )
        return begin <= parsed < end

    return qf.custom(is_in_range, dt_field)


def is_success_taxi_order(
        status_field: AnyStr = 'status',
        taxi_status_field: AnyStr = 'taxi_status',
):
    """
    Get filter that passes records where the status field equals
    b'finished' and the taxi status field equals b'complete'
    :param status_field: name of the field with the order status
    :param taxi_status_field: name of the field with the taxi status
    :return: nile custom filter over both fields
    """

    def is_finished_and_complete(status, taxi_status):
        # bytes because of YT
        return status == b'finished' and taxi_status == b'complete'

    return nf.custom(
        is_finished_and_complete,
        status_field,
        taxi_status_field,
    )


def sample_by_field(
        field_name: AnyStr, sample_prob: float, encoding: str = 'utf-8',
):
    """
    Get filter to sample Records by sample_prob

    The decision is deterministic: the md5 of the field value is reduced
    to one of 100 buckets (0..99) and the record passes when its bucket
    is below 100 * sample_prob, so equal values always get the same
    outcome and the effective granularity is one percent.
    :param field_name: str field name to compute hash
    :param sample_prob: float prob to pass through filter, within [0, 1]
    :param encoding: encoding for six.ensure_binary
    :return: nile filter
    """
    assert 0 <= sample_prob <= 1

    # 100 * sample_prob may overshoot an integer by a rounding error
    # (100 * 0.07 == 7.000000000000001), which would let one extra bucket
    # through. Rounding removes that noise; it is computed once here
    # instead of per record.
    threshold = round(100 * sample_prob, 9)

    def filter_func(value: AnyStr):
        digest = hashlib.md5(
            six.ensure_binary(value, encoding=encoding),
        ).hexdigest()
        bucket = int(digest, 16) % 100
        return bucket < threshold

    return nf.custom(filter_func, field_name)