import logging

from datetime import datetime, timedelta

from sandbox import sdk2
from sandbox.sandboxsdk import environments
from sandbox.projects.common.binary_task import deprecated as binary_task
from sandbox.projects.maps.common import juggler_alerts


# Values of the helper "ExperimentInfo" column that the mapper attaches to rows.
# The column is the secondary sort key of the map-reduce in this file, so within
# one UUID the rows tagged 0 (those carrying TestBuckets) reach the reducer
# before the rows tagged 1 (report rows).
RECORD_WITH_EXPERIMENTS = 0
RECORD_WITHOUT_EXPERIMENTS = 1

# YT directory listed by the task; child names are compared against
# "%Y-%m-%d" date strings to detect today's and yesterday's entries.
REPORTS_DIR = "//home/navigator-user-report/daily-reports"
# YT directory joined with the same child name as the report; only the
# TestBuckets and UUID columns of that table are read.
ACTIVE_USERS_DIR = "//home/maps/statistics/navi/production/active_users"

# YT pool passed in the spec of the map-reduce, sort and transform operations.
YT_POOL = "maps-core-navi-statistics"

def extract_experiments(test_buckets):
    """Build a comma-separated string of experiment ids.

    Args:
        test_buckets: iterable of mappings, each having an ``id`` key.
            ``None`` or an empty iterable is treated as "no experiments".

    Returns:
        The ``id`` values converted with ``str`` and joined by ``,``;
        an empty string when there are no buckets.

    Raises:
        KeyError: if a bucket has no ``id`` key.
    """
    # A row may carry the TestBuckets key with a null value; iterating over
    # None would raise TypeError, so treat it like an empty bucket list.
    if not test_buckets:
        return ''
    return ','.join(str(item['id']) for item in test_buckets)

def mapper(row):
    """Tag an input row with its kind and a common ``UUID`` key.

    Rows that have both ``TestBuckets`` and ``UUID`` are tagged with
    RECORD_WITH_EXPERIMENTS. Other rows that have ``uuid`` are tagged with
    RECORD_WITHOUT_EXPERIMENTS and get ``UUID`` copied from ``uuid``.
    Rows matching neither shape are dropped.

    Args:
        row: mutable mapping; it is modified in place and yielded.

    Yields:
        The same ``row`` object with ``ExperimentInfo`` (and possibly
        ``UUID``) set, or nothing at all.
    """
    carries_buckets = "TestBuckets" in row and "UUID" in row
    if not carries_buckets and "uuid" not in row:
        # Neither an active-users record nor a report record: emit nothing.
        return

    if carries_buckets:
        row["ExperimentInfo"] = RECORD_WITH_EXPERIMENTS
    else:
        row["ExperimentInfo"] = RECORD_WITHOUT_EXPERIMENTS
        # Expose the lower-case key under the name used for reduce_by.
        row["UUID"] = row["uuid"]
    yield row


def reducer(key, rows):
    """Attach experiment ids to the report rows of one reduce key.

    Relies on the input order: a row with ``TestBuckets`` must come before
    the report rows it applies to (the map-reduce call in this file sorts by
    ``UUID`` and then ``ExperimentInfo``). If several such rows are present,
    the most recently seen one wins; if none was seen yet,
    ``experiment_ids`` is ``None``.

    Args:
        key: reduce key; not used by the function body.
        rows: iterable of mutable mappings produced by the mapper.

    Yields:
        Every row that has a ``report`` key (and no ``TestBuckets`` key), with
        ``experiment_ids`` set and the helper columns ``ExperimentInfo`` and
        ``UUID`` removed. Other rows are not emitted.
    """
    experiment_ids = None
    for record in rows:
        if "TestBuckets" in record:
            experiment_ids = extract_experiments(record["TestBuckets"])
            continue
        if "report" not in record:
            continue

        record["experiment_ids"] = experiment_ids
        # Drop the columns that were added only for sorting/grouping.
        for helper_column in ("ExperimentInfo", "UUID"):
            del record[helper_column]
        yield record


def _has_experiments_column(schema):
    """Return True if *schema* (iterable of column mappings) has an ``experiment_ids`` column."""
    return any(
        'name' in column and column['name'] == 'experiment_ids'
        for column in schema
    )


class NavigatorSortUserReports(binary_task.LastBinaryTaskRelease, juggler_alerts.TaskJugglerReportWithParameters):
    """Add experiment ids to a daily user-report table, then sort and recompress it.

    Each run walks the children of REPORTS_DIR and stops after the first table
    that is processed successfully. Tables whose processing fails with
    YtOperationFailedError are logged and skipped; the last such error is
    re-raised once the walk is over.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = [
            environments.PipEnvironment("yandex-yt"),
        ]

    class Parameters(sdk2.Task.Parameters):

        # Secret whose "token" key is used as the YT token.
        yav_secret = sdk2.parameters.YavSecret("YAV secret (sec-***)")
        # Codec passed to yt.transform when the report table is rewritten.
        compression_codec = sdk2.parameters.String("Compression codec", default='zstd_9')
        ext_params = binary_task.binary_release_parameters(stable=True)
        juggler_report_is_enabled = True
        juggler_host_name = 'sandbox.NavigatorSortUserReports'
        juggler_service_name = 'task_failure'

    def get_yt_client(self):
        """Create a YT client for the hahn cluster using the token from ``yav_secret``."""
        # Imported lazily: yt.wrapper comes from the task's pip environment.
        from yt.wrapper import YtClient

        return YtClient(
            proxy="hahn.yt.yandex.net",
            token=self.Parameters.yav_secret.data()["token"])

    def do_processing(self, yt, report_path, active_users_path):
        """Join experiments into one report table, sort it and write it back.

        Args:
            yt: YT client used for all operations.
            report_path: path of the report table; it is both an input of the
                map-reduce and the destination of the final transform.
            active_users_path: path of the table providing TestBuckets per UUID.

        The steps are: map-reduce both inputs into a temporary table created
        with the report's schema, sort that table by ``uuid``, transform it
        into ``report_path`` with the configured compression codec, and mark
        ``report_path`` with the ``_with_experiments`` attribute.
        """
        schema = yt.get_attribute(report_path, 'schema', default={})
        attrs = {'schema': schema}

        with yt.TempTable(attributes=attrs) as tmp:
            logging.info("Adding experiments to {}".format(report_path))
            yt.run_map_reduce(
                mapper,
                reducer,
                [report_path, active_users_path],
                tmp,
                reduce_by=["UUID"],
                # Secondary key puts rows with TestBuckets ahead of report rows.
                sort_by=["UUID", "ExperimentInfo"],
                spec={
                    "data_size_per_map_job": 1 * 1024 * 1024 * 1024,
                    "reduce_job_io": {"table_writer": {"max_row_weight": 128 * 1024 * 1024}},
                    "pool": YT_POOL
                }
            )

            logging.info("Sorting {}".format(report_path))
            yt.run_sort(
                tmp,
                sort_by="uuid",
                spec={"pool": YT_POOL})

            logging.info("Compressing {}".format(report_path))
            yt.transform(
                tmp,
                report_path,
                compression_codec=self.Parameters.compression_codec,
                spec={"pool": YT_POOL})

            # Marker checked by on_execute to skip already processed tables.
            yt.set_attribute(report_path, "_with_experiments", True)

    def on_save(self):
        """Delegate to the parent implementation."""
        super(NavigatorSortUserReports, self).on_save()

    def on_execute(self):
        """Find the first report table that still needs experiment ids and process it.

        A child of REPORTS_DIR is skipped when:
          * its name equals today's date (UTC, ``%Y-%m-%d``);
          * its name equals yesterday's date and the matching active-users
            table does not exist;
          * it is not a table;
          * it already has the ``_with_experiments`` attribute;
          * its schema has no ``experiment_ids`` column.

        Raises:
            YtOperationFailedError: the last operation failure seen during the
                walk, re-raised after the loop ends.
        """
        super(NavigatorSortUserReports, self).on_execute()
        logging.basicConfig(level=logging.INFO)

        from yt.wrapper import ypath_join, YtOperationFailedError

        yt = self.get_yt_client()

        reports = yt.list(REPORTS_DIR)
        logging.info("Found {} reports: {}".format(len(reports), reports))

        now = datetime.utcnow()
        today = now.strftime("%Y-%m-%d")
        yesterday = (now - timedelta(days=1)).strftime("%Y-%m-%d")
        logging.info("Today is {}".format(today))

        last_error = None  # Memorize last error to throw it in the end

        for report_date in reports:
            if report_date == today:
                logging.info("Skipping today's report table")
                continue

            report_path = ypath_join(REPORTS_DIR, report_date)
            active_users_path = yt.TablePath(ypath_join(ACTIVE_USERS_DIR, report_date), columns=["TestBuckets", "UUID"])
            if report_date == yesterday and not yt.exists(active_users_path):
                logging.info("Active users for this date doesn't exist, skipping")
                continue

            logging.info(report_path)

            file_type = yt.get(report_path + "/@type")
            if file_type != "table":
                continue

            with_experiments = yt.has_attribute(report_path, '_with_experiments')
            if with_experiments:
                logging.info("Table already has experiment ids, skipping")
                continue

            schema = yt.get_attribute(report_path, 'schema', default={})
            if not _has_experiments_column(schema):
                logging.info("Table schema doesn't have experiment_ids column, skipping it")
                continue

            try:
                self.do_processing(yt, report_path, active_users_path)
            except YtOperationFailedError as e:
                # Keep the traceback in the log and try the next table.
                logging.exception(e)
                last_error = e
            else:
                break  # process one table at a time
        if last_error is not None:
            raise last_error
