import time
from datetime import datetime, timedelta

from sandbox.sandboxsdk.parameters import (
    SandboxStringParameter,
    SandboxIntegerParameter,
    SandboxFloatParameter,
)

from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox.projects.common.userdata import sample_base_task, util, mr_base_task


class EndTime(SandboxStringParameter):
    # Optional unix timestamp of the last period to sample. When left empty,
    # the task falls back to "now minus 10 hours", rounded down to 1800 s.
    group = mr_base_task.INPUT_PARAMS_GROUP_NAME
    name = 'end_time'
    required = False
    description = 'Last timestamp of fast user_sessions (optional, will use last available period if empty):'


class TableNameTemplate(SandboxStringParameter):
    # str.format-style template: each sampled period's timestamp is
    # substituted via template.format(timestamp) to get the table name.
    group = mr_base_task.INPUT_PARAMS_GROUP_NAME
    name = 'table_name_template'
    required = True
    description = 'Template for table name to sample:'


class TimeStep(SandboxIntegerParameter):
    # Distance in seconds between consecutive sampled periods; the task
    # rejects values that are not multiples of 1800.
    group = mr_base_task.INPUT_PARAMS_GROUP_NAME
    name = 'time_step'
    required = True
    default_value = 1800
    description = 'Time interval between periods, seconds (should be multiple of 1800):'


class NumPeriods(SandboxIntegerParameter):
    # How many periods to sample, counting back from end_time:
    # start_time = end_time - time_step * (num_periods - 1).
    group = mr_base_task.INPUT_PARAMS_GROUP_NAME
    name = 'num_periods'
    required = True
    default_value = 25
    description = 'Number of periods to sample:'


class RecsFrac(SandboxFloatParameter):
    # Passed unchanged as `frac` to get_sample_by_uid_command for every
    # sampled table.
    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    name = 'recs_frac'
    required = True
    default_value = 0.001
    description = 'Base fraction of sessions to collect:'


class SampleUserDataFastSessions(sample_base_task.Task):
    """
    Sample fast user_sessions.

    Samples ``num_periods`` tables spaced ``time_step`` seconds apart and
    ending at ``end_time`` (inclusive). Each table name is
    ``table_name_template.format(timestamp)``. All timestamps and the step
    must be multiples of 1800 seconds.
    """

    type = 'SAMPLE_USERDATA_FAST_SESSIONS'

    input_parameters = util.smart_join_params(
        sample_base_task.Task.input_parameters,
        TableNameTemplate,
        EndTime,
        TimeStep,
        NumPeriods,
        RecsFrac
    )

    # Required granularity (seconds) of end_time and time_step.
    _ALIGNMENT = 1800

    def _get_time_step(self):
        """
        Return the validated ``time_step`` parameter, in seconds.

        Raises SandboxTaskFailureError unless it is a positive multiple of
        1800. Zero must be rejected explicitly: it passes the modulo check,
        but it would make the sampling loop never advance.
        """
        step = int(self.ctx[TimeStep.name])
        if step <= 0:
            raise SandboxTaskFailureError("time_step={} must be positive".format(step))
        if step % self._ALIGNMENT != 0:
            raise SandboxTaskFailureError("Time step is not a multiple of 1800")
        return step

    def do_get_dates_context(self):
        """
        Compute the time range to sample.

        Returns a dict with:
          * ``start_time`` / ``end_time``: unix timestamps of the first and
            last sampled periods, both multiples of 1800;
          * ``descr``: identifier built from the table name template and range.

        Raises SandboxTaskFailureError if ``end_time`` is not an integer, is
        misaligned or lies in the future, or if ``time_step`` / ``num_periods``
        are not positive.
        """
        raw_end_time = self.ctx.get(EndTime.name)
        if raw_end_time:
            try:
                end_time = int(raw_end_time)
            except (TypeError, ValueError):
                raise SandboxTaskFailureError(
                    "end_time={!r} is not an integer timestamp".format(raw_end_time))
        else:
            # No explicit end: go 10 hours back from now (presumably so the
            # newest tables already exist - TODO confirm) and round down to
            # the alignment boundary.
            end_time = int(time.mktime((datetime.now() - timedelta(hours=10)).timetuple()))
            end_time -= end_time % self._ALIGNMENT

        if end_time % self._ALIGNMENT != 0:
            raise SandboxTaskFailureError("end_time={} is not a multiple of 1800".format(end_time))

        step = self._get_time_step()

        if datetime.fromtimestamp(end_time) >= datetime.now():
            raise SandboxTaskFailureError("{} is in the future".format(end_time))

        num_periods = int(self.ctx[NumPeriods.name])
        if num_periods <= 0:
            # Otherwise start_time > end_time and nothing would be sampled.
            raise SandboxTaskFailureError("num_periods={} must be positive".format(num_periods))

        start_time = end_time - step * (num_periods - 1)

        return {
            "start_time": start_time,
            "end_time": end_time,
            'descr': "yandex_{}-{}-{}".format(self.ctx[TableNameTemplate.name], start_time, end_time)
        }

    def do_mr_sample(self):
        """
        Queue one sampling command per period in the date range, then run them.

        For every timestamp from ``start_time`` to ``end_time`` (inclusive),
        stepping by ``time_step``, samples the table
        ``table_name_template.format(timestamp)`` with fraction ``recs_frac``.
        All commands are added to one ``util.ProcessRunner`` and executed by
        its ``run()``.
        """
        pr = util.ProcessRunner()
        dates = self.get_dates_context()

        # Re-validate here so a zero step can never reach the loop.
        step = self._get_time_step()
        name_template = self.ctx[TableNameTemplate.name]
        frac = self.ctx[RecsFrac.name]

        for period_time in range(int(dates['start_time']), int(dates['end_time']) + 1, step):
            pr.add(
                "user_sessions." + str(period_time),
                self.get_sample_by_uid_command(
                    table=name_template.format(period_time),
                    frac=frac
                )
            )
        pr.run()

    def updated_result_attrs(self, attrs):
        """
        Extend the base task's result attributes with the sampled range and
        the ``time_step`` / ``num_periods`` parameters that produced it.
        """
        attrs = sample_base_task.Task.updated_result_attrs(self, attrs)
        dates = self.get_dates_context()
        attrs.update({
            "start_time": dates["start_time"],
            "end_time": dates["end_time"],
            "time_step": int(self.ctx[TimeStep.name]),
            "num_periods": int(self.ctx[NumPeriods.name])
        })
        return attrs


# Module-level alias to the task class; presumably the hook the Sandbox task
# loader looks up to discover this task - confirm against the loader.
__Task__ = SampleUserDataFastSessions
