import typing as tp

from nile.api.v1 import (
    extractors as ne,
    datetime as nd,
    utils as nu,
    extended_schema,
    multischema,
    Record,
    Job,
    stream as nstream,
)
from qb2.api.v1 import (
    extractors as qe,
    typing as qt,
)

from maps.wikimap.stat.libs.common.lib import geobase_region

from .calculate_metrics import (
    calculate_feedback_metrics,
    calculate_users_metrics,
)

from .common import (
    WindowsConfig,
    make_datetime,
)
from .constants import (
    MEMORY_LIMIT,
    EXPERIMENTS_DIMENSIONS,
    NULL_KEY,
    FEEDBACK_METRICS_TYPE,
    FORMS_METRICS_TYPE,
    USERS_METRICS_TYPE,
    EXPERIMENTS_METRICS_TYPE,
)
from .forms_metrics import (
    prepare_bebr_data,
    calculate_daily_forms_metrics,
    aggregate_forms_metrics_in_windows,
)
from .nmaps_data import join_nmaps_data
from .original_task import make_extractor, OriginalTask


# Input tables read by make_job().
# Feedback tasks themselves.
FEEDBACK_TASK_TABLE_PATH: str = "home/maps/core/nmaps/dynamic/replica/fbapi/feedback_task"
# Status change history of feedback tasks; consumed by construct_task_history().
FEEDBACK_TASK_CHANGES_TABLE_PATH: str = "home/maps/core/nmaps/dynamic/replica/fbapi/feedback_task_changes"
# Regions table joined to feedback by region_id in extract_regions().
REGIONS_TABLE_PATH: str = "home/maps/core/nmaps/analytics/geo-data/major_regions_map"

# Tables passed to join_nmaps_data() by make_job().
NMAPS_FEEDBACK_DB_TABLE: str = "home/maps/core/nmaps/analytics/feedback/db/feedback_latest"
# Directory, not a table: make_job() appends the dump date to obtain the table path.
NMAPS_FEEDBACK_FROM_FBAPI_PATH: str = "home/maps/core/nmaps/analytics/feedback/fbapi"

# Table whose "test_id" column is joined to feedback in prepare_experiments_data().
TEST_IDS_TABLE: str = "home/maps/core/nmaps/analytics/feedback_metrics/meta/test_id"


def is_in_progress(status: str) -> bool:
    """Return True when the task status is exactly "in_progress"."""
    expected_status = "in_progress"
    return status == expected_status


def is_need_info(status: str) -> bool:
    """Return True when the task status is exactly "need_info"."""
    expected_status = "need_info"
    return status == expected_status


def is_published(status: str) -> bool:
    """Return True when the task status is exactly "published"."""
    expected_status = "published"
    return status == expected_status


def is_accepted(status: str) -> bool:
    """Return True when the task status is exactly "accepted"."""
    expected_status = "accepted"
    return status == expected_status


def is_rejected(status: str) -> bool:
    """Return True when the task status is exactly "rejected"."""
    expected_status = "rejected"
    return status == expected_status


def is_post_resolution_status(status: str) -> bool:
    """Return True when the status is "published", "accepted" or "rejected"."""
    for resolved_status in ("published", "accepted", "rejected"):
        if status == resolved_status:
            return True
    return False


def extract_regions(feedback_table: nstream.Stream, regions_table: nstream.Stream) -> nstream.Stream:
    """
    Attach region information to feedback records.

    A region_id is extracted for every feedback record with
    OriginalTask.extract_region_id (the projection is given the geobase files
    and memory limit), and the result is joined by region_id (default join
    type) with the regions table reduced to region_tree, region_id and
    region_name. Every "/" in region_name is replaced with a tab character.
    """
    def slashes_to_tabs(name):
        return name.replace("/", "\t")

    regions_lookup: nstream.Stream = regions_table.project(
        "region_tree",
        "region_id",
        region_name=ne.custom(slashes_to_tabs, "region_name").with_type(str),
    )

    feedback_with_region_id: nstream.Stream = feedback_table.project(
        ne.all(),
        region_id=make_extractor(OriginalTask.extract_region_id),
        files=geobase_region.FILES,
        memory_limit=geobase_region.GEOBASE_JOB_MEMORY_LIMIT,
    )

    return feedback_with_region_id.join(
        regions_lookup,
        by="region_id",
        memory_limit=MEMORY_LIMIT,
    )


def get_sprav_feedback_type(form_id: str, question_id: str, answer_id: str) -> str:
    """
    Derive the feedback type for form_id='organization' from question_id and answer_id.

    Returns NULL_KEY for any other form. For the 'organization' form:
      * "poi_entrance" -- question 'add_object' answered with 'entrance', or one of
        the entrance questions ('entrance_problem', 'organization_entrance',
        'wrong_entrance');
      * "poi_status" -- questions 'closed' and 'opened';
      * "poi" -- everything else, including 'add_object'/'organization',
        'other'/'comment' and any unknown combination.

    Code is taken from here:
    https://a.yandex-team.ru/arc/trunk/arcadia/maps/wikimap/stat/nmaps_feedback/lib/sprav_data.py?rev=r8054092#L12
    """
    if form_id != "organization":
        return NULL_KEY

    entrance_questions = ("entrance_problem", "organization_entrance", "wrong_entrance")
    status_questions = ("closed", "opened")

    if question_id == "add_object" and answer_id == "entrance":
        return "poi_entrance"
    if question_id in entrance_questions:
        return "poi_entrance"
    if question_id in status_questions:
        return "poi_status"

    # Known "poi" combinations and every unknown combination end up here.
    return "poi"


def extract_dimensions(feedback_table: nstream.Stream) -> nstream.Stream:
    """
    Add the dimension fields used for metric breakdowns to feedback records.

    All incoming fields except region_id and region_tree are passed through.
    The added fields are produced by OriginalTask extractors, except `type`,
    which is computed by get_sprav_feedback_type from form_id, question_id and
    answer_id.
    """
    def original_task_field(field_name: bytes):
        return make_extractor(OriginalTask.extract, field_name)

    # Insertion order is kept so the fields reach project() in the same order as before.
    dimension_extractors = {
        "form_id": original_task_field(b"form_id"),
        "client_id": make_extractor(OriginalTask.extract_client_id),
        "form_type": original_task_field(b"form_type"),
        "form_context_id": original_task_field(b"form_context_id"),
        "client_context_id": make_extractor(OriginalTask.extract_client_context_id),
        "question_id": original_task_field(b"question_id"),
        "answer_id": original_task_field(b"answer_id"),
        "type": ne.custom(get_sprav_feedback_type, "form_id", "question_id", "answer_id").with_type(str),
        "some_user_id": make_extractor(OriginalTask.extract_some_user_id),
    }

    passthrough_fields = ne.all(exclude=["region_id", "region_tree"])
    return feedback_table.project(passthrough_fields, **dimension_extractors)


def construct_task_history(feedback_changes_table: nstream.Stream) -> nstream.Stream:
    """
    Scans through feedback history and calculates timestamps used in metrics:
    resolved_at, published_at, need_info_at, un_need_info_at.

    The input is projected to "id", "status", "created_at" and "task_id", grouped
    by task_id and sorted by created_at. The result has one record per task with
    fields id (taken from task_id), resolved_at, published_at, need_info_at and
    un_need_info_at; each timestamp is None when the corresponding event was not
    found.

    The most tricky timestamp is resolved_at.
    1. resolved_at is equal to the timestamp of the last status update to 'accepted' or 'rejected'.
    2. If there are no status updates to 'accepted' or 'rejected', but there is a 'published', then
    we assume that there was an 'accepted' in the same moment right before 'published'.
    3. But if feedback is reopened (transferred from 'accepted'/'rejected' to 'in_progress'),
    we ignore history before that moment.

    A couple of examples:
    1. published at 1 => resolved_at == 1
    2. accepted at 1, published at 2 => resolved_at == 1
    3. accepted at 1, in_progress at 2 => resolved_at == None
    4. accepted at 1, in_progress at 2, accepted at 3 => resolved_at == 3
    5. accepted at 1, in_progress at 2, published at 3 => resolved_at == 3
    6. accepted at 1, in_progress at 2, accepted at 3, published at 4 => resolved_at == 3

    need_info_at is the timestamp of the last switch to 'need_info';
    un_need_info_at is the timestamp of the first status change after it.
    """
    @nu.with_hints({
        "id": str,
        "resolved_at": qt.Optional[qt.Unicode],
        "published_at": qt.Optional[qt.Unicode],
        "need_info_at": qt.Optional[qt.Unicode],
        "un_need_info_at": qt.Optional[qt.Unicode],
    })
    def history_reducer(groups: tp.Iterable[tp.Tuple[Record, tp.Iterable[Record]]]) -> tp.Generator[Record, None, None]:
        for key, records in groups:
            resolved_at: tp.Optional[str] = None
            published_at: tp.Optional[str] = None

            need_info_at: tp.Optional[str] = None
            un_need_info_at: tp.Optional[str] = None

            # Status changes of one task, in created_at order (see the sort below).
            for record in records:
                # "created_at" of a change record is the moment the status was set.
                updated_at: str = record.get("created_at")
                task_status: str = record.get("status")

                # Reopening discards any resolution/publication seen so far.
                if is_in_progress(task_status):
                    resolved_at = None
                    published_at = None

                if is_accepted(task_status) or is_rejected(task_status):
                    resolved_at = updated_at
                if is_published(task_status):
                    published_at = updated_at
                    # 'published' without an explicit 'accepted'/'rejected' also
                    # counts as the resolution moment.
                    if resolved_at is None:
                        resolved_at = updated_at

                if is_need_info(task_status):
                    # A new 'need_info' period starts: forget the end of the previous one.
                    need_info_at = updated_at
                    un_need_info_at = None
                else:
                    # The first change to any other status closes the open 'need_info' period.
                    if need_info_at is not None and un_need_info_at is None:
                        un_need_info_at = updated_at

            yield Record(
                id=key["task_id"],
                resolved_at=resolved_at,
                published_at=published_at,
                need_info_at=need_info_at,
                un_need_info_at=un_need_info_at)

    return feedback_changes_table.project(
        "id", "status", "created_at", "task_id"
    ).groupby("task_id").sort("created_at").reduce(
        history_reducer,
        memory_limit=MEMORY_LIMIT
    )


def fix_missing_history(feedback_table: nstream.Stream) -> nstream.Stream:
    """
    Fill in history timestamps that are absent for a task's current status.

    Sometimes feedback has missing history, i.e. it has a resolved status, but
    feedback_task_changes has no record of this status change. Very few
    feedback tasks are affected; the majority of them were created before
    Sep 2020, because task history wasn't supported at that time.
    In such cases the missing resolved/published/need_info time is assumed to
    be equal to the creation time.
    https://st.yandex-team.ru/NMAPS-14048
    """
    @nu.with_hints(output_schema=extended_schema())
    def fix_history_mapper(records: tp.Iterable[Record]) -> tp.Generator[Record, None, None]:
        for task in records:
            status: str = task.get("status")
            fallback_time: tp.Optional[str] = task.get("created_at")

            timestamps: tp.Dict[str, tp.Optional[str]] = {
                "resolved_at": task.get("resolved_at"),
                "published_at": task.get("published_at"),
                "need_info_at": task.get("need_info_at"),
            }

            # A timestamp is patched only when the current status implies the
            # event happened but the history did not provide its time.
            if timestamps["resolved_at"] is None and is_post_resolution_status(status):
                timestamps["resolved_at"] = fallback_time
            if timestamps["published_at"] is None and is_published(status):
                timestamps["published_at"] = fallback_time
            if timestamps["need_info_at"] is None and is_need_info(status):
                timestamps["need_info_at"] = fallback_time

            yield Record(task, **timestamps)

    return feedback_table.map(fix_history_mapper, memory_limit=MEMORY_LIMIT)


def annotate_with_dates(feedback_table: nstream.Stream, windows_config: WindowsConfig) -> nstream.Stream:
    """Annotate records with fielddate and window_days

    We calculate metrics for all dates at once in a common workflow. But for this we need to have a separate dataset
    for each 'fielddate' and window. In other words, each record must have a 'fielddate' and window fields and be
    duplicated for the desired dates. We could just add each fielddate for each record from the database. But it will
    increase the amount of data len(fielddates) * len(windows) times. And we will not use most of it for different
    dates. However, the execution time will suffer significantly.

    Instead for each pair (record, fielddate) we check if it will be used at all beforehand.

    Our three kinds of metrics take into account the appropriate kind of feedback records for each date:
        * actions metrics: with 'created' and 'published' actions on that date;
        * unresolved metrics: currently unresolved feedback -- the date is after created and before resolved;
        * resolution metrics: with 'resolved' action on that date.

    First of all, we filter out all irrelevant feedback -- that was created after or published before the date range of
    interest. Then, for each (record, fielddate) pair, we check whether the record has an action on that date.
    If so, we take this pair. If not, we check whether the record was in an 'unresolved' state on that date and take it
    if it was. Otherwise, we ignore this pair.

    A record counts as 'unresolved' for a window when it was created before the window end and either has never been
    resolved (resolved_at is None) or was resolved at/after the window end.

    Each emitted record is the input record extended with fielddate, window_start, window_end (ISO dates of the
    window boundaries) and window_days.
    """

    @nu.with_hints(output_schema=extended_schema(
        fielddate=str,
        window_start=str,
        window_end=str,
        window_days=qt.Int32,
    ))
    def annotate_with_dates_mapper(records: tp.Iterable[Record]) -> tp.Generator[Record, None, None]:
        for record in records:
            created_at: tp.Optional[nd.Datetime] = make_datetime(record.get("created_at"))
            resolved_at: tp.Optional[nd.Datetime] = make_datetime(record.get("resolved_at"))
            published_at: tp.Optional[nd.Datetime] = make_datetime(record.get("published_at"))

            # NOTE(review): assumes every record has a created_at; the comparisons
            # below would fail on None -- confirm against the input table.
            if created_at > windows_config.right_margin:
                continue

            if published_at is not None and published_at < windows_config.left_margin:
                continue

            timepoints = (created_at, resolved_at, published_at)

            for fielddate, window, window_days in windows_config.windows():
                has_action_in_window: bool = any(window.includes(timepoint) for timepoint in timepoints)

                if not has_action_in_window:
                    # Never-resolved feedback stays unresolved for every window that
                    # ends after its creation, so a missing resolved_at must not
                    # exclude the record here.
                    still_unresolved: bool = resolved_at is None or resolved_at >= window.end
                    if not (created_at < window.end and still_unresolved):
                        continue

                yield Record(
                    record,
                    fielddate=fielddate,
                    window_start=window.start.date().isoformat(),
                    window_end=window.end.date().isoformat(),
                    window_days=window_days)

    return feedback_table.map(annotate_with_dates_mapper)


def prepare_experiments_data(
        feedback_table: nstream.Stream,
        test_ids_table: nstream.Stream) -> nstream.Stream:
    """
    Prepare feedback records for experiments metrics.

    The test ids of every record are extracted with OriginalTask.extract_test_ids
    into a hidden `test_ids` field, which is unfolded into a `test_id` field
    (presumably one output record per list element -- see qb2 `unfold`). The
    result is joined by test_id (default join type) with the `test_id` column
    of test_ids_table.
    """
    known_test_ids = test_ids_table.project("test_id")

    feedback_per_test_id = feedback_table.project(
        ne.all(),
        qe.unfold("test_id", "test_ids").with_type(str),
        test_ids=make_extractor(OriginalTask.extract_test_ids).hide(),
    )

    return feedback_per_test_id.join(known_test_ids, by="test_id")


def split_metrics_by_weeks(feedback_metrics: nstream.Stream, dates: tp.List[str]) -> tp.Iterable[nstream.Stream]:
    """
    Split a metrics stream into one stream per week of `dates`.

    `dates` is cut into consecutive chunks of 7 (the last one may be shorter).
    A record goes to the stream whose chunk contains the record's `fielddate`.

    :param feedback_metrics: stream of records with a `fielddate` field.
    :param dates: all field dates, in the order that defines the weeks.
    :return: an iterable with exactly one stream per week, so it can be
        zipped with format_output_tables_names(dates).
    """
    # One output per week. Declaring one output per date would leave about
    # 6/7 of the outputs permanently empty.
    weeks_count: int = (len(dates) + 6) // 7

    # Precomputed date -> week index, instead of a linear dates.index() per record.
    # setdefault keeps the first occurrence, as dates.index() would.
    week_by_date: tp.Dict[str, int] = {}
    for position, date in enumerate(dates):
        week_by_date.setdefault(date, position // 7)

    output_schema = multischema(*[extended_schema() for _ in range(weeks_count)])

    @nu.with_hints(outputs_count=weeks_count, output_schema=output_schema)
    def split_mapper(records: tp.Iterable[Record], *outputs: tp.Callable[[Record], None]) -> None:
        for record in records:
            # A fielddate missing from `dates` raises KeyError and fails the operation.
            week: int = week_by_date[record.get("fielddate")]
            outputs[week](record)

    metrics_by_weeks: nstream.Stream = feedback_metrics.map(split_mapper)

    # split_metrics_by_weeks have to always return an iterable, but
    # map operation with 1 output does not produce an iterable.
    if weeks_count == 1:
        return [metrics_by_weeks]
    return metrics_by_weeks


def format_output_tables_names(dates: tp.List[str]) -> tp.List[str]:
    """
    Return the names of the weekly output tables.

    Each week is named after its first date, i.e. every 7th element of
    `dates` starting from the first one.
    """
    return [dates[week_start] for week_start in range(0, len(dates), 7)]


def save_weekly_tables(metrics, dates, result_path, metric_type):
    """
    Split metrics by weeks and store every weekly table under {result_path}/{metric_type}.

    Each table is named after the first date of its week:
    week 0 is stored into {result_path}/{metric_type}/{dates[0]}
    week 1 is stored into {result_path}/{metric_type}/{dates[7]}
    ...
    and so on.
    """
    weekly_streams: tp.Iterable[nstream.Stream] = split_metrics_by_weeks(metrics, dates)
    week_names: tp.List[str] = format_output_tables_names(dates)
    destination_dir = f"{result_path}/{metric_type}"

    # zip() stops at the shorter of the two sequences.
    for weekly_stream, week_name in zip(weekly_streams, week_names):
        weekly_stream.put(f"{destination_dir}/{week_name}")


def make_job(
    job: Job,
    from_date: str,
    to_date: str,
    nmaps_from_fbapi_dump_date: str,
    result_path: str,
    windows_days: tp.List[int],
    *,
    use_bebr: bool
) -> None:
    """
    Build the whole feedback metrics pipeline on the given job.

    Feedback tasks are enriched with regions, dimensions, task history and
    NMaps data, annotated with field dates and windows, and then turned into
    feedback, users and experiments metrics. Every metrics stream is stored as
    weekly tables under {result_path}/{metric type}. Forms metrics are added
    only when `use_bebr` is set.

    :param job: job the input tables and operations are attached to.
    :param from_date: passed to WindowsConfig together with `to_date`
        (presumably the boundaries of the date range -- see WindowsConfig).
    :param to_date: see `from_date`.
    :param nmaps_from_fbapi_dump_date: name of the table inside
        NMAPS_FEEDBACK_FROM_FBAPI_PATH to read.
    :param result_path: root path of the output tables.
    :param windows_days: passed to WindowsConfig as `windows_days`.
    :param use_bebr: when true, forms metrics are calculated from bebr data
        and saved as well.
    """
    # Input tables.
    feedback_table: nstream.Stream = job.table(FEEDBACK_TASK_TABLE_PATH).label("feedback_table")
    feedback_changes_table: nstream.Stream = job.table(FEEDBACK_TASK_CHANGES_TABLE_PATH).label("feedback_changes_table")
    regions_table: nstream.Stream = job.table(REGIONS_TABLE_PATH).label("regions_table")

    nmaps_feedback_db_table: nstream.Stream = job.table(NMAPS_FEEDBACK_DB_TABLE).label("nmaps_feedback_db_table")
    nmaps_from_fbapi_feedback_table: nstream.Stream = job.table(
        "/".join((NMAPS_FEEDBACK_FROM_FBAPI_PATH, nmaps_from_fbapi_dump_date))
    ).label("nmaps_from_fbapi_feedback_table")

    test_ids_table: nstream.Stream = job.table(TEST_IDS_TABLE).label("test_ids_table")

    windows_config: WindowsConfig = WindowsConfig(from_date, to_date, windows_days=windows_days)

    # Common dataset all metrics below are calculated from. The left join keeps
    # tasks without any history records; fix_missing_history patches them next.
    full_feedback_data: nstream.Stream = feedback_table.call(
        extract_regions,
        regions_table
    ).label("feedback_with_regions_table").call(
        extract_dimensions
    ).label("feedback_with_dimensions_table").join(
        construct_task_history(feedback_changes_table).label("feedback_history"),
        by="id",
        type="left",
        memory_limit=MEMORY_LIMIT
    ).label("feedback_with_history_table").call(
        fix_missing_history
    ).label("feedback_with_fixed_history_table").call(
        join_nmaps_data,
        nmaps_feedback_db_table,
        nmaps_from_fbapi_feedback_table
    ).label("feedback_with_nmaps_data_table").call(
        annotate_with_dates,
        windows_config
    ).label("full_feedback_data")

    feedback_metrics = full_feedback_data.call(
        calculate_feedback_metrics
    ).label("feedback_metrics")

    users_metrics = full_feedback_data.call(
        calculate_users_metrics
    ).label("users_metrics")

    # Experiments metrics are feedback metrics calculated over the
    # experiments dimensions for records with a known test_id.
    experiments_metrics = full_feedback_data.call(
        prepare_experiments_data,
        test_ids_table
    ).call(
        calculate_feedback_metrics,
        EXPERIMENTS_DIMENSIONS
    ).label("experiments_metrics")

    dates: tp.List[str] = windows_config.dates()
    save_weekly_tables(feedback_metrics, dates, result_path, FEEDBACK_METRICS_TYPE)
    save_weekly_tables(users_metrics, dates, result_path, USERS_METRICS_TYPE)
    save_weekly_tables(experiments_metrics, dates, result_path, EXPERIMENTS_METRICS_TYPE)

    # Forms metrics come from a separate (bebr) source and are optional.
    if use_bebr:
        bebr_data: nstream.Stream = prepare_bebr_data(job, windows_config)
        forms_metrics: nstream.Stream = calculate_daily_forms_metrics(bebr_data).call(
            aggregate_forms_metrics_in_windows, windows_config
        )
        save_weekly_tables(forms_metrics, dates, result_path, FORMS_METRICS_TYPE)
