from datetime import date, timedelta
import logging
import collections

from sandbox.sandboxsdk.parameters import (
    SandboxStringParameter,
    SandboxIntegerParameter,
    SandboxFloatParameter,
)
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox.projects.common.userdata import sample_base_task, util, mr_base_task


class LastDate(SandboxStringParameter):
    """
    Last (end) day of the sampled period, as YYYYMMDD.

    Optional: when left empty, the task ends the period 3 days before today.
    The period then extends backwards from this date.
    """
    name = 'last_date'
    # The value is used as the END of the period, so the description says so
    # (it previously called it the start date).
    description = 'Last (end) date of period, YYYYMMDD (optional, will use last available period if empty):'
    required = False
    group = mr_base_task.INPUT_PARAMS_GROUP_NAME


class DaysPerPeriod(SandboxIntegerParameter):
    """Length of a single sampling period in days (a week unless overridden)."""

    group = mr_base_task.INPUT_PARAMS_GROUP_NAME
    name = 'days_per_period'
    required = True
    default_value = 7
    description = 'Days per period:'


class NumberOfPeriods(SandboxIntegerParameter):
    """
    How many consecutive periods to sample.

    The total sampled span is num_periods * days_per_period days.
    """

    group = mr_base_task.INPUT_PARAMS_GROUP_NAME
    name = 'num_periods'
    required = True
    default_value = 4
    description = 'Number of periods to sample:'


class SpyLogBaseFrac(SandboxFloatParameter):
    """Sampling fraction passed as `frac` when sampling daily spy_log tables by uid."""

    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    name = 'spy_log_frac'
    required = True
    default_value = 0.0001
    description = 'Base fraction of spy_log to collect:'


class SimilarGroupBaseFrac(SandboxFloatParameter):
    """Sampling fraction passed as `frac` when sampling daily similargroup tables by uid."""

    group = mr_base_task.MISC_PARAMS_GROUP_NAME
    name = 'similargroup_frac'
    required = True
    default_value = 0.0002
    description = 'Base fraction of similargroup to collect:'


class WhiteUsersFrac(SandboxFloatParameter):
    """
    Fraction of users to mark "white".

    The value is passed as --frac to user_sessions_create_synthetic_white_users.
    The default of 1.0 takes all users.
    """
    name = 'white_users_frac'
    description = 'Take this fraction of users "white":'
    required = True
    # Every other parameter in this module declares its default via
    # `default_value`; the previous `default = 1` was not that attribute, so
    # the intended default was never applied.
    default_value = 1.0
    group = mr_base_task.MISC_PARAMS_GROUP_NAME


class SampleUserBrowseSessions(sample_base_task.Task):
    """
    Sample spy_log and similargroup sessions to test user_browse/update.py

    For every day of the requested period this task:
      * samples the cleaned spy_log and similargroup daily tables by uid;
      * builds synthetic "white users" tables from the spy_log tables under the
        task's destination prefix.

    It also samples the most recent main_comm_q snapshot that has both the
    doppnorm and aggnorm tables.
    """

    type = 'SAMPLE_USER_BROWSE_SESSIONS'

    input_parameters = util.smart_join_params(
        sample_base_task.Task.input_parameters,
        LastDate,
        DaysPerPeriod,
        NumberOfPeriods,
        SpyLogBaseFrac,
        SimilarGroupBaseFrac,
        WhiteUsersFrac
    )

    # main_comm_q sub-directories that must all exist for a date to be usable
    # as the comm cache snapshot (also the order their sample commands are queued).
    _COMM_CACHE_NORMS = ('doppnorm', 'aggnorm')

    def get_strat_data_dir(self):
        return 'quality/user_browse/scripts/testing'

    def do_get_dates_context(self):
        """
        Compute the sampled date range.

        The range ends at the `last_date` parameter, or 3 days before today
        when that parameter is empty. It spans
        num_periods * days_per_period days, inclusive of both ends.

        Returns a dict with:
          * 'start_date' / 'end_date': strings in util.date2str format;
          * 'num_days': the span length;
          * 'descr': "<start>-<end>".

        Raises SandboxTaskFailureError if the span is empty, or if the end
        date is today or later.
        """
        periods = self.ctx[NumberOfPeriods.name]
        days_per_period = self.ctx[DaysPerPeriod.name]
        num_days = periods * days_per_period
        if num_days <= 0:
            # Otherwise start_date would land after end_date and no sessions
            # would be sampled at all.
            raise SandboxTaskFailureError(
                "Nothing to sample: {} periods of {} days".format(periods, days_per_period)
            )

        if self.ctx.get(LastDate.name):
            end_date = util.str2date(self.ctx[LastDate.name])
        else:
            end_date = date.today() - timedelta(3)

        if end_date >= date.today():
            raise SandboxTaskFailureError("{} is in the future".format(util.date2str(end_date)))
        start_date = end_date - timedelta(num_days - 1)

        return {
            "start_date": util.date2str(start_date),
            "end_date": util.date2str(end_date),
            "num_days": num_days,
            "descr": util.date2str(start_date) + "-" + util.date2str(end_date)
        }

    def do_mr_sample(self):
        """
        Queue and run all sampling commands for the configured period.

        Session samples (spy_log, similargroup, main_comm_q) run first. The
        white-users commands run only afterwards; presumably this is because
        they read the spy_log tables under get_tables_prefix() that the first
        batch produces (confirm).

        Side effect: stores the chosen main_comm_q date in
        self.ctx["comm_cache_date"], which updated_result_attrs reads.

        Raises SandboxTaskFailureError if no main_comm_q date has all of
        _COMM_CACHE_NORMS.
        """
        dates = self.get_dates_context()
        first_day = util.str2date(dates['start_date'])
        last_day = util.str2date(dates['end_date'])

        sessions_runner = util.ProcessRunner()
        white_users_runner = util.ProcessRunner()

        day = first_day
        while day <= last_day:
            self._add_daily_commands(sessions_runner, white_users_runner, day)
            day += timedelta(1)

        comm_cache_date, latest = self._find_latest_comm_cache_tables()
        self.ctx["comm_cache_date"] = comm_cache_date
        logging.info("The latest one are %s", latest)

        for norm in self._COMM_CACHE_NORMS:
            sessions_runner.add(
                "main_comm_q." + norm,
                self.get_mr_sample_command(
                    table=latest[norm],
                    dst_table="user_browse/main_comm_q/{}/_{}".format(norm, comm_cache_date),
                    count=50000
                )
            )

        sessions_runner.run()
        white_users_runner.run()

    def _add_daily_commands(self, sessions_runner, white_users_runner, day):
        """Queue the spy_log/similargroup uid samples and the white-users build for `day`."""
        date_str = util.date2str(day)
        date_str_yt = util.date2str_yt(day)

        sessions_runner.add(
            "spylog." + date_str,
            self.get_sample_by_uid_command(
                table="user_sessions/pub/spy_log/daily/{}/clean".format(date_str_yt),
                frac=self.ctx[SpyLogBaseFrac.name],
                strat_config="strat-spy.json",
                key_prefix="y"
            )
        )
        sessions_runner.add(
            "similargroup." + date_str,
            self.get_sample_by_uid_command(
                table="user_sessions/pub/similargroup/daily/{}/clean".format(date_str_yt),
                frac=self.ctx[SimilarGroupBaseFrac.name],
                strat_config="strat-similargroup.json",
                key_prefix="sg"
            )
        )

        white_users_runner.add(
            "white_users." + date_str,
            "{env} {bin_dir}/user_sessions_create_synthetic_white_users "
            "--server {server} --frac {frac} --salt '{salt}' "
            "--source {dst_prefix}user_sessions/pub/spy_log/daily/{date_yt}/clean "
            "--dest {dst_prefix}home/antifraud/daily/cleaning/{date_yt}/spy_log_white_users",
            env=self.get_client_environ_str(),
            bin_dir=self.ctx["bin_dir"],
            server=self.ctx["mr_server"],
            salt=self.ctx["sampling_salt"],
            frac=self.ctx[WhiteUsersFrac.name],
            dst_prefix=self.get_tables_prefix(),
            date_yt=date_str_yt
        )

    def _find_latest_comm_cache_tables(self):
        """
        Pick the most recent main_comm_q date that has every table in _COMM_CACHE_NORMS.

        Table paths are grouped by their parent directory name (the "kind",
        e.g. doppnorm/aggnorm) and by the last 8 characters of their name (the
        date). Paths with fewer than 4 components, or whose date part is not
        numeric, are ignored. Extra kinds present for a date do not disqualify it.

        Returns (date_string, {kind: table_path}).
        Raises SandboxTaskFailureError if no date qualifies.
        """
        tables = self.mr_client.get_tables_list(
            self.get_mr_src_prefix() + 'userfeat/user_browse/main_comm_q'
        )
        logging.info("got %s as main_comm_q tables", tables)

        by_date = collections.defaultdict(dict)
        for table in tables:
            parts = table.split("/")
            if len(parts) < 4:
                continue
            kind = parts[-2]
            table_date = parts[-1][-8:]
            if not table_date.isdigit():
                continue
            by_date[table_date][kind] = table

        logging.info("they are %s ", by_date)

        required = set(self._COMM_CACHE_NORMS)
        for table_date in sorted(by_date, reverse=True):
            # Superset check: only the required kinds are used downstream, so
            # unrelated extra kinds must not make a complete date unusable.
            if required.issubset(by_date[table_date]):
                return table_date, by_date[table_date]

        raise SandboxTaskFailureError("Can't choose latest date for comm_cache")

    def updated_result_attrs(self, attrs):
        """
        Extend the base result attributes with this task's period settings.

        Also adds the main_comm_q date chosen by do_mr_sample, which is read
        from self.ctx["comm_cache_date"].
        """
        attrs = sample_base_task.Task.updated_result_attrs(self, attrs)
        dates = self.get_dates_context()
        attrs.update({
            'days_per_period': self.ctx[DaysPerPeriod.name],
            'num_periods': self.ctx[NumberOfPeriods.name],
            'last_date': dates['end_date'],
            'comm_cache_date': self.ctx["comm_cache_date"]
        })
        return attrs


# Module-level alias for this module's task class; presumably the Sandbox
# loader discovers the task through the `__Task__` name (confirm).
__Task__ = SampleUserBrowseSessions
