# coding=utf-8
import logging
import datetime

from cached_property import cached_property
import yt.wrapper as yt
from yt.wrapper.ypath import ypath_join

from crypta.lib.python.yql_runner.base_parser import BaseParser
from crypta.ltp.viewer.lib import ltp_logs


# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)


class LTPIndexUpdate(BaseParser):
    """Incrementally update the LTP ids index table on YT.

    The date range to process is derived from the index table's
    ``_max_history_date`` attribute unless explicit dates are given. The job
    collects the LTP log directories that have tables inside the history window
    and exposes everything as context for the ``main.sql.j2`` query template.
    Rendering and execution are presumably done by ``BaseParser.run``.
    """

    QUERY_TEMPLATE = "main.sql.j2"
    # Attributes kept on the index table to record what it already covers.
    MIN_HISTORY_DATE_ATTRIBUTE = "_min_history_date"
    MAX_HISTORY_DATE_ATTRIBUTE = "_max_history_date"
    LOGS_ATTRIBUTE = "_logs"
    # Dates are handled as strings in this zero-padded format. Lexicographic
    # comparison of such strings matches chronological order, which the
    # min()/max()/<= comparisons below rely on.
    DATE_FORMAT = "%Y-%m-%d"
    DEFAULT_HISTORY_PERIOD = 365  # in days

    def __init__(
        self,
        yt_proxy,
        pool,
        crypta_ltp_viewer_udf_url,
        hdd_index_path=None,
        start_date=None,
        end_date=None,
        min_history_date=None,
        is_embedded=False,
        **kwargs
    ):
        """Resolve the processing dates and check that there is input to process.

        :param yt_proxy: YT proxy, passed through to ``BaseParser``.
        :param pool: pool, passed through to ``BaseParser``.
        :param crypta_ltp_viewer_udf_url: UDF URL exposed to the query template.
        :param hdd_index_path: index table path; defaults to
            ``<default LTP dir>/ids_index_compact_hdd``.
        :param start_date: first date to process (``DATE_FORMAT``); defaults to
            the day after the index's ``_max_history_date`` attribute.
        :param end_date: last date to process; defaults to yesterday.
        :param min_history_date: oldest history date to consider; defaults to
            ``DEFAULT_HISTORY_PERIOD`` days ago.
        :param is_embedded: passed through to ``BaseParser``.
        :raises RuntimeError: if ``start_date`` is after ``end_date``, if
            ``get_start_date`` fails, or if no log directory contains a table
            dated between ``min_history_date`` and ``end_date``.
        """
        date = datetime.datetime.today().strftime(self.DATE_FORMAT)
        super(LTPIndexUpdate, self).__init__(date, yt_proxy, pool, is_embedded=is_embedded, **kwargs)
        self.hdd_index_path = hdd_index_path or ypath_join(self.get_default_ltp_dir(), "ids_index_compact_hdd")

        self.start_date = start_date or self.get_start_date(self.hdd_index_path)
        logger.info("start_date = %s", self.start_date)

        self.end_date = end_date or self.get_end_date()
        logger.info("end_date = %s", self.end_date)

        if self.start_date > self.end_date:
            raise RuntimeError("Start date is bigger than end date.")

        self.min_history_date = min_history_date or self.get_default_min_history_date()
        logger.info("min_history_date = %s", self.min_history_date)

        # Fall back to the requested bounds when the index has no such attributes yet.
        self.current_min_history_date = yt.get_attribute(
            self.hdd_index_path, self.MIN_HISTORY_DATE_ATTRIBUTE, default=self.min_history_date
        )
        logger.info("current_min_history_date = %s", self.current_min_history_date)

        self.current_max_history_date = yt.get_attribute(
            self.hdd_index_path, self.MAX_HISTORY_DATE_ATTRIBUTE, default=self.end_date
        )
        logger.info("current_max_history_date = %s", self.current_max_history_date)

        if not self.input_paths:
            # input_paths filters on [min_history_date, end_date], so report that range.
            raise RuntimeError(
                "There are no prepared tables for date range from {start} to {end} in directories: {dirs}.".format(
                    start=self.min_history_date,
                    end=self.end_date,
                    dirs=",\n".join(log_dir for _, log_dir in self.log_dirs),
                )
            )

        # Log only the directories that actually contribute input.
        logger.info("input_paths:\n\t%s", ",\n\t".join(dir_path for _, dir_path in self.input_paths))

        self.crypta_ltp_viewer_udf_url = crypta_ltp_viewer_udf_url

    def get_default_ltp_dir(self):
        """Return the LTP directory for the current ``crypta_env``."""
        return "//home/crypta/{env}/portal/ltp".format(env=self.crypta_env)

    def get_start_date(self, table):
        """Return the day after ``table``'s max history date, as a ``DATE_FORMAT`` string.

        :param table: YT path of the index table.
        :raises RuntimeError: if the max-history-date attribute is not set, or
            if the next day is today or later (the index is already as fresh as
            it can be).
        """
        # "None" serves as the sentinel returned for a missing attribute.
        max_history_date_attribute = yt.get_attribute(table, self.MAX_HISTORY_DATE_ATTRIBUTE, default="None")

        if max_history_date_attribute == "None":
            raise RuntimeError(
                "Attribute {attr} not set on table {table}. Please, pass or set it and try again.".format(
                    attr=self.MAX_HISTORY_DATE_ATTRIBUTE, table=table
                )
            )

        shifted_date = datetime.datetime.strptime(max_history_date_attribute, self.DATE_FORMAT) + datetime.timedelta(
            days=1
        )
        shifted_date_str = shifted_date.strftime(self.DATE_FORMAT)

        # Both sides must use the same format for the string comparison to be chronological.
        today_str = datetime.datetime.today().strftime(self.DATE_FORMAT)
        if shifted_date_str >= today_str:
            raise RuntimeError("The maximum freshness of one day is already there.")

        return shifted_date_str

    @cached_property
    def computed_logs(self):
        """Log directories listed in the index's ``_logs`` attribute (``[]`` if absent).

        Presumably written by a previous run through ``user_attrs``; verify in
        the query template.
        """
        return yt.get_attribute(self.hdd_index_path, self.LOGS_ATTRIBUTE, default=[])

    def get_end_date(self):
        """Return yesterday's date as a ``DATE_FORMAT`` string."""
        yesterday = datetime.datetime.today() - datetime.timedelta(days=1)

        return yesterday.strftime(self.DATE_FORMAT)

    def get_default_min_history_date(self):
        """Return the date ``DEFAULT_HISTORY_PERIOD`` days ago as a ``DATE_FORMAT`` string."""
        history_start = datetime.datetime.today() - datetime.timedelta(days=self.DEFAULT_HISTORY_PERIOD)

        return history_start.strftime(self.DATE_FORMAT)

    @cached_property
    def log_dirs(self):
        """``(name, path)`` pairs for every entry of ``ltp_logs.LOGS``."""
        return [(log.name, log.path) for log in ltp_logs.LOGS]

    @cached_property
    def input_paths(self):
        """``(source, dir_path)`` pairs for log directories with input in the history window.

        A directory qualifies when at least one of its child names falls in
        ``[min_history_date, end_date]``. The child names are assumed to be
        ``DATE_FORMAT`` dates; TODO confirm.
        """
        return [
            (source, dir_path)
            for source, dir_path in self.log_dirs
            # any() stops at the first matching child, like the original loop's break.
            if any(self.min_history_date <= date <= self.end_date for date in yt.list(dir_path, absolute=False))
        ]

    def get_context_data(self, **kwargs):
        """Extend the base template context with LTP index update parameters.

        ``user_attrs`` is a YSON string holding the widened history bounds and
        the full list of log directories. It is presumably applied to the
        resulting index table by the query template.
        """
        context = super(LTPIndexUpdate, self).get_context_data(**kwargs)
        context.update(
            start_date=self.start_date,
            end_date=self.end_date,
            min_history_date=self.min_history_date,
            current_min_history_date=self.current_min_history_date,
            input_paths=self.input_paths,
            index=self.hdd_index_path,
            computed_logs=self.computed_logs,
            crypta_ltp_viewer_udf_url=self.crypta_ltp_viewer_udf_url,
            user_attrs=yt.yson.dumps({
                self.MIN_HISTORY_DATE_ATTRIBUTE: min(self.start_date, self.current_min_history_date, self.min_history_date),
                self.MAX_HISTORY_DATE_ATTRIBUTE: max(self.end_date, self.current_max_history_date),
                self.LOGS_ATTRIBUTE: [log_dir for _, log_dir in self.log_dirs],
            }).decode(),
        )
        return context

    @cached_property
    def output_schema(self):
        """YT schema of the index table: sorted by ``(id, id_type)``."""
        return [
            {"name": "id", "type": "string", "sort_order": "ascending"},
            {"name": "id_type", "type": "string", "sort_order": "ascending"},
            {"name": "sources", "type": "any"},
            {"name": "recs_count", "type": "uint64"},
        ]

    def run(self, **kwargs):
        """Log the processed date range and delegate to ``BaseParser.run``."""
        log_date = self.start_date if self.start_date == self.end_date else f"{self.start_date} {self.end_date}"
        logger.info("Running LTP index update for %s", log_date)
        super(LTPIndexUpdate, self).run(**kwargs)
        logger.info("Finished LTP index updating")
