from datetime import date, timedelta
import logging

from sandbox.sandboxsdk.parameters import (
    SandboxStringParameter,
    SandboxIntegerParameter,
    SandboxFloatParameter,
)

from sandbox.sandboxsdk.errors import SandboxTaskFailureError
from sandbox.projects.common.userdata import sample_base_task, util, mr_base_task


class LastDate(SandboxStringParameter):
    # Optional string parameter: the last day of the sampled period.
    # When left empty, do_get_dates_context() falls back to "today minus 3 days".
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    required = False
    name = 'last_date'
    description = 'Last date of sessions period, YYYYMMDD (optional, will use last available period if empty):'


class DaysPerPeriod(SandboxIntegerParameter):
    # Required integer parameter: length of the sampled period in days.
    # do_get_dates_context() rejects values of 3 or less.
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    required = True
    default_value = 30
    name = 'days_per_period'
    description = 'Days per period:'


class BaseFrac(SandboxFloatParameter):
    # Required float parameter: passed as `frac` to the per-day sampling command
    # built in do_mr_sample().
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    required = True
    default_value = 0.00005
    name = 'frac'
    description = 'Base fraction sessions to collect:'


class SampleUserCountersSessions(sample_base_task.Task):
    """
    Sample daily user_sessions watch_log tables over a period of days.

    The resulting sample is intended as test input for user_counters/update.py.
    """

    type = 'SAMPLE_USER_COUNTERS_SESSIONS'

    input_parameters = util.smart_join_params(
        sample_base_task.Task.input_parameters,
        DaysPerPeriod,
        LastDate,
        BaseFrac
    )

    def get_strat_data_dir(self):
        """Return the strat data directory path used by this task."""
        return 'quality/user_counters/scripts/testing'

    def do_get_dates_context(self):
        """
        Build the dictionary describing the period to sample.

        The period ends on the LastDate parameter when it is set, otherwise
        three days before today, and covers DaysPerPeriod days inclusive.
        The last 40% of the period (rounded down, at least one day) is the
        "tail"; `intermediate_date` is the end date minus the tail, and
        `days_per_partial_period` is the period length without the tail.

        Returns a dict with keys: start_date, end_date, num_days,
        days_per_partial_period, intermediate_date, descr. Dates are
        formatted with util.date2str.

        Raises SandboxTaskFailureError if the end date is today or later,
        or if the period is 3 days or shorter.
        """
        period_len = self.ctx[DaysPerPeriod.name]
        explicit_last = self.ctx.get(LastDate.name)
        period_end = util.str2date(explicit_last) if explicit_last else date.today() - timedelta(days=3)
        period_start = period_end - timedelta(days=period_len - 1)

        # The end date is checked first, then the period length.
        if date.today() <= period_end:
            raise SandboxTaskFailureError("date {} is in the future".format(util.date2str(period_end)))

        if 3 >= period_len:
            raise SandboxTaskFailureError("period length cannot be <= 3 days")

        tail_len = max(1, int(period_len * 0.4))

        start_str = util.date2str(period_start)
        end_str = util.date2str(period_end)
        split_str = util.date2str(period_end - timedelta(days=tail_len))

        context = {}
        context["start_date"] = start_str
        context["end_date"] = end_str
        context["num_days"] = period_len
        context["days_per_partial_period"] = period_len - tail_len
        context["intermediate_date"] = split_str
        context["descr"] = start_str + '-' + end_str
        return context

    def do_mr_sample(self):
        """
        Queue one sampling command per day of the period and run them all.

        For every day from start_date to end_date (inclusive) a command
        produced by get_sample_by_uid_command is registered in a
        util.ProcessRunner under the name "watch_log_tskv.<day>"; the
        runner is started once all days are queued.
        """
        period = self.get_dates_context()
        first_day = util.str2date(period['start_date'])
        final_day = util.str2date(period['end_date'])

        runner = util.ProcessRunner()

        # Walk the period by day offset; an inverted period yields no iterations.
        for offset in range((final_day - first_day).days + 1):
            day_yt = util.date2str_yt(first_day + timedelta(days=offset))
            source_table = "user_sessions/pub/watch_log_tskv/daily/{}/clean".format(day_yt)
            logging.info("going to sample {}".format(source_table))
            job_name = "watch_log_tskv." + day_yt
            command = self.get_sample_by_uid_command(
                table=source_table,
                frac=self.ctx[BaseFrac.name],
                strat_config="strat-watch.json",
                key_prefix="y"
            )
            runner.add(job_name, command)

        runner.run()

    def updated_result_attrs(self, attrs):
        """
        Extend the base task's result attributes with period information.

        Adds last_date, intermediate_date, real_period (the full period
        length from the task context) and period (the partial period length)
        to the attributes returned by the parent implementation.
        """
        result = sample_base_task.Task.updated_result_attrs(self, attrs)
        period = self.get_dates_context()
        # Build the whole extra mapping first so `result` is updated in one step.
        extra = {
            'last_date': period['end_date'],
            'intermediate_date': period['intermediate_date'],
            'real_period': self.ctx[DaysPerPeriod.name],
            'period': period['days_per_partial_period'],
        }
        result.update(extra)
        return result


# Module-level alias for the task class defined in this file
# (presumably the name the Sandbox task loader looks up — TODO confirm).
__Task__ = SampleUserCountersSessions
