import contextlib
import datetime
import json
import logging
import re
import time
import uuid

import crypta.lib.python.bt.conf.conf as conf
import yt.wrapper as yt

from dateutil.parser import parse as parse_dt
from cached_property import cached_property
from crypta.lib.python.bt.tasks import TransactionalYtTask
from crypta.lib.python.bt.workflow import IndependentTask, Parameter
from crypta.lib.python.solomon.reporter import create_throttled_solomon_reporter
from crypta.lib.python.yql_runner.task import YQLRunnerTask, ConfMixin

logger = logging.getLogger(__name__)


def split_log_path(path):
    """Describe a log table path for monitoring.

    :param path: YT path of a table located under ``.../logs/<log_name>/...``;
        any object whose ``str()`` is such a path is accepted.
    :returns: tuple ``(is_eod, period, log_name)``. For logfeller stream tables
        (``.../<period>/<date>T<time>`` with period ``5min``, ``30min`` or ``1h``)
        ``is_eod`` tells whether the table is the last one of its day. Any other
        path is treated as a daily table: ``(True, "1d", log_name)``.
    :raises ValueError: if the path has no ``/logs/<log_name>/`` segment.
    """
    path = str(path)
    match = re.search(r"/logs/([^/]+)/", path)
    if match is None:
        raise ValueError("Not a log table path: {!r}".format(path))
    log_name = match.group(1)

    # start time of the last table of a day, per logfeller granularity
    last_times = {"5min": "23:55:00", "30min": "23:30:00", "1h": "23:00:00"}

    parts = path.split("/")
    if len(parts) >= 3 and parts[-2] in last_times:  # logfeller logs
        period = parts[-2]
        return parts[-1].split("T")[1] == last_times[period], period, log_name
    return True, "1d", log_name


def get_table_date(path):
    """Best-effort extraction of a datetime from the last component of a table path.

    Tries, in order:
      1. ``dateutil`` parsing of the whole name (``%Y-%m-%d...`` style names);
      2. a leading unix timestamp (``<timestamp>-...`` names), converted as UTC;
      3. falls back to the current local time when neither works.

    NOTE(review): (1) may yield a tz-aware datetime if a name carries an offset,
    and (3) is local time while (2) is UTC; callers compare the result with naive
    ``datetime.datetime.now()`` values — confirm table names never carry offsets.
    """
    suffix = str(path).split("/")[-1]
    try:
        return parse_dt(suffix)  # as well as %Y-%m-%d*** format
    except Exception:
        # deliberate best-effort: dateutil may raise several error types for
        # unparsable names; fall through to the timestamp format
        pass
    try:
        return datetime.datetime.utcfromtimestamp(int(suffix.split("-", 1)[0]))  # as well as timestamp*** format
    except (ValueError, OverflowError, OSError):
        return datetime.datetime.now()  # can't get date from table


def _is_mounted(ytclient, path):
    """Tell whether the first tablet of the dynamic table at ``path`` is mounted.

    NOTE(review): only tablet 0 is inspected; a table with several tablets could
    be partially mounted — confirm that is acceptable for the callers.
    """
    state_path = "{0}/@tablets/0/state".format(path)
    tablet_state = ytclient.get(state_path)
    return tablet_state == "mounted"


def _mount_processed_table(ytclient, table):
    """Ensure the processed-state dynamic table exists and is mounted.

    When ``table`` is missing it is created (with parent nodes) as a dynamic
    table sorted by ``(log_source, path)`` with an ``any``-typed ``process_time``
    column. It is then mounted synchronously if tablet 0 is not ``mounted``.

    :raises AssertionError: if the table is still not mounted after mounting.
    """
    if not ytclient.exists(table):
        ytclient.create(
            "table",
            path=table,
            recursive=True,
            ignore_existing=True,
            attributes={
                "schema": [
                    {"name": "log_source", "required": True, "sort_order": "ascending", "type": "string"},
                    {"name": "path", "required": True, "sort_order": "ascending", "type": "string"},
                    {"name": "process_time", "required": False, "type": "any"},
                ],
                "dynamic": True,
                "optimize_for": "scan",
            },
        )

    if not _is_mounted(ytclient, table):
        ytclient.mount_table(table, sync=True)
        # explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so existing handlers still match
        if not _is_mounted(ytclient, table):
            raise AssertionError("Table still not mounted")


class ProcessedMixin(TransactionalYtTask):

    """Mount and read processed dyn table

    Keeps track, in the dynamic table ``conf.paths.stream.processed``, of which
    input tables each task (identified by ``key``) has already processed. It wraps
    the task run so that only not-yet-processed tables are selected.
    """

    # limit count of tables to process at one task
    limit = Parameter(parse=int, default=1000)
    # lag in seconds after which "nothing to process" becomes an error in _assert_tables_ready
    allowed_lag = Parameter(parse=int, default=int(1.5 * 3600))

    # log source object; its ``Name`` becomes ``ls`` (empty string when None)
    log_source = None
    # if True will not check that processed records match log_source
    update_all_log_sources = False

    # string paths of already processed tables; subclasses override it
    # NOTE(review): class-level mutable set shared by all instances; only read here
    processed = set()

    @cached_property
    def key(self):
        """Task identifier stored in the processed table and output attributes: the class name."""
        return self.__class__.__name__

    def mount_processed(self):
        """Create (if needed) and mount the processed-state dynamic table."""
        table = conf.paths.stream.processed
        _mount_processed_table(self.yt, table)

    def input(self):
        """Return an iterable of candidate input tables; must be overridden."""
        raise NotImplementedError()

    def output(self):
        """Return an iterable of output tables (none by default)."""
        return []

    def postprocess_input_table(self, tbl):
        """Hook run for each input table, inside a transaction, after the state update."""
        pass

    def postprocess_input_table_non_transactional(self, tbl):
        """Hook run for each input table, outside the postprocess transaction, after the state update."""
        pass

    @cached_property
    def date(self):
        """Today's local date, fixed at first access."""
        return datetime.date.today()

    @cached_property
    def finish_date(self):
        """
        Approx date when task will ready
        Needed to correct process app metrica
        until luigi has no logic to wait all log is import
        """
        return (datetime.datetime.now() + datetime.timedelta(hours=6)).date()

    @cached_property
    def ls(self):
        """return log source string representation"""
        return self.log_source.Name if self.log_source else ""

    @cached_property
    def output_idstorage_dir(self):
        """Directory for id storage output (from config)."""
        return conf.paths.stream.id_storage

    @cached_property
    def output_extra_data_dir(self):
        """Per-task directory for extra data output."""
        return yt.ypath.ypath_join(conf.paths.stream.extra_data, self.key)

    @cached_property
    def uniqid(self):
        """Unique run id: ``<unix seconds>-<uuid4>``, fixed at first access."""
        return "{time:.0f}-{uuid}".format(time=time.time(), uuid=uuid.uuid4())

    @cached_property
    def limit_day(self):
        """Input tables dated at or before this moment (now - 14 days) are ignored."""
        return datetime.datetime.now() - datetime.timedelta(days=14)

    # TODO: C901, make it simple again
    @contextlib.contextmanager
    def run_context(self):
        """Select unprocessed input tables, run the task, then record them as processed.

        Inside one master transaction, this method does the following:
          - list inputs and outputs;
          - keep the input tables that are not in ``processed`` and are dated after
            ``limit_day``, sorted by table date and capped at ``limit``
            (``self.unprocessed_tables``);
          - create the output tables and yield the parent ``run_context``.

        After the run, the outputs get a ``processed`` attribute. Once the master
        transaction is over, the state is written to the dyn table and the
        postprocess hooks run. If the wrapped body raises, nothing after the
        ``yield`` is executed.
        """
        self.mount_processed()
        processed = self.processed

        # NO_USE_TRANSACTION switches to the module-level yt.Transaction instead of the task client's
        if hasattr(self, "NO_USE_TRANSACTION"):
            trx_constructor = yt.Transaction
        else:
            trx_constructor = self.yt.Transaction

        affected_log_tables = set()

        with trx_constructor():
            self.input_tables = list(self.input())
            self.output_tables = list(self.output())

            self.unprocessed_tables = sorted(
                [x for x in self.input_tables if str(x) not in processed and get_table_date(x) > self.limit_day],
                key=get_table_date,
            )[: self.limit]

            logger.info("Input tables %s", self.input_tables)
            logger.info("Unprocessed tables %s", self.unprocessed_tables)
            logger.info("Output tables %s", self.output_tables)

            with self.yt.Transaction():
                for t in self.output_tables:
                    self.yt.create("table", path=t, recursive=True, ignore_existing=True)

            with super(ProcessedMixin, self).run_context() as ctx:
                yield ctx

            logger.info("Update processed tables %s", conf.paths.stream.processed)
            # stamp every output table with the inputs it was built from
            for t in self.output_tables:
                self.yt.set(
                    "{output_table}/@processed".format(output_table=t),
                    {"key": self.key, "source": self.unprocessed_tables, "ls": self.ls},
                )

            for tbl in self.unprocessed_tables:
                # Merger tasks don't get the original log tables as input,
                # those are stored in 'processed' attribute instead
                data = self.yt.get_attribute(str(tbl), "processed", None)
                if data:
                    actual_tables = data["source"]
                else:
                    actual_tables = [str(tbl)]

                affected_log_tables.update(actual_tables)

        # A dyn table transaction can't have a parent master transaction, so the
        # state is written only after the master transaction is finished
        self._update_state(affected_log_tables)
        self._postprocess()

    def _assert_tables_ready(self, unprocessed_tables, input_tables):
        """Decide whether the task may run given the tables selected for it.

        :returns: True when there is something to process, or when the lag
            can't be computed (the next step is left to fail). False when nothing
            is unprocessed and the newest input is younger than ``allowed_lag``.
        :raises RuntimeError: when nothing is unprocessed and the newest input
            is at least ``allowed_lag`` seconds old.
        """
        if len(unprocessed_tables):
            # ok let task run
            return True

        try:
            date_from_table = max(map(get_table_date, input_tables))
            lag = (datetime.datetime.now() - date_from_table).total_seconds()
        except Exception:
            logger.exception("Lag calculate error")
            # let it fail in next step
            return True

        if lag < self.allowed_lag:
            # so tables is not yet ready, wait
            return False

        raise RuntimeError("Tables still not ready, lag is `{}`".format(lag))

    def _update_state(self, affected_log_tables):
        """Record in the processed dyn table that this task processed ``affected_log_tables``.

        The observed sources are ``self.log_sources`` when set and non-empty,
        otherwise ``[self.ls]``. For every path, the existing rows are updated
        with ``process_time[self.key] = now``. Only rows whose ``log_source`` is
        an observed source are updated, or all rows when
        ``update_all_log_sources`` is set. If no row matches, one row per
        observed source is inserted.
        """
        # Dyn table transaction can't have a parent master transaction, so use transaction split.
        # 'http' backend doesn't support tablet transactions and 'rpc' doesn't support running operations:
        # set 'rpc' temporarily and restore the previous backend after the transaction is completed.
        yt_client = self._init_yt()
        previous_backend = yt_client.config.get("backend")
        yt_client.config["backend"] = "rpc"
        try:
            with yt_client.Transaction(type="tablet"):
                parse_time = int(time.time())
                # the observed sources do not depend on the table, compute them once
                observed_log_sources = getattr(self, "log_sources", []) or [self.ls]

                for atbl in affected_log_tables:
                    recs = list(
                        yt_client.select_rows(
                            """
                        log_source, path, process_time FROM [{conf.paths.stream.processed}]
                        WHERE path = {tbl!r}
                    """.format(
                                conf=conf, tbl=atbl
                            )
                        )
                    )

                    updated = [
                        record
                        for record in recs
                        if record["log_source"] in observed_log_sources or self.update_all_log_sources
                    ]

                    if updated:
                        for record in updated:
                            if record["process_time"]:
                                record["process_time"][self.key] = parse_time
                            else:
                                record["process_time"] = {self.key: parse_time}
                        yt_client.insert_rows(conf.paths.stream.processed, updated, update=True)
                    else:
                        records = []
                        pattern = {"path": atbl, "process_time": {self.key: parse_time}}
                        for source in observed_log_sources:
                            local = pattern.copy()
                            local.update(log_source=source)
                            records.append(local)
                        yt_client.insert_rows(conf.paths.stream.processed, records, update=True)
        finally:
            # the tablet transaction is already committed/aborted here
            yt_client.config["backend"] = previous_backend

    def _postprocess(self):
        """Run the per-input-table postprocess hooks: transactional ones first, then the rest."""
        # postprocess only after success report to dyn table
        with self.yt.Transaction():
            for table in self.input_tables:
                self.postprocess_input_table(table)
        for table in self.input_tables:
            self.postprocess_input_table_non_transactional(table)


class SolomonMixin(object):
    """Push task metrics to Solomon.

    NOTE(review): relies on the host class for ``yt``, ``key``, ``ls`` and
    ``unprocessed_tables`` (ProcessedMixin provides them) — confirm for other hosts.
    """

    @cached_property
    def solomon(self):
        """Throttled Solomon reporter configured from ``conf.proto.SolomonConfig``."""
        return create_throttled_solomon_reporter(
            push_interval=0.1,
            oauth_token=conf.proto.SolomonConfig.Token,
            project=conf.proto.SolomonConfig.Project,
            cluster=conf.proto.SolomonConfig.Cluster,
            service=conf.proto.SolomonConfig.Service,
            url=conf.proto.SolomonConfig.Url,
        )

    def solomonify(self, start_at, finish_at):
        """Make solomon report

        Reports, labelled by task key and log source:
          - a ``log`` lag sensor per unprocessed table;
          - ``output_rows``;
          - ``processed_tables``;
          - ``duration``;
          - ``lag``: seconds since the date of the last unprocessed table, or -1
            when unknown.

        :param start_at: run start, as returned by ``time.time()``
        :param finish_at: run finish, as returned by ``time.time()``
        """
        if hasattr(self, "output_table"):
            rows = self.yt.get("{output_table}/@row_count".format(output_table=self.output_table))
        else:
            # tasks without an output table (e.g. soup) report 0 rows
            # todo: may be set all soup table rows?
            rows = 0
        processed = len(self.unprocessed_tables)

        if not self.unprocessed_tables:
            # nothing was processed in this run: the lag is unknown, which is
            # expected and not worth a traceback
            lag = -1
        else:
            try:
                date_from_table = get_table_date(self.unprocessed_tables[-1])
                lag = (datetime.datetime.now() - date_from_table).total_seconds()
            except Exception:
                logger.exception("Lag calculate error")
                lag = -1

        for table in self.unprocessed_tables:
            try:
                date_from_table = get_table_date(table)
                is_eod, period, log_name = split_log_path(table)
                table_lag = (datetime.datetime.now() - date_from_table).total_seconds()
                self.solomon.set_value(
                    ts_datetime=date_from_table,
                    sensor="log",
                    value=table_lag,
                    labels={"task": self.key, "ls": self.ls, "table": log_name, "eod": is_eod, "granularity": period},
                )
            except Exception:
                # best effort: one unexpected table name must not break the whole report
                logger.exception("Table notify solomon error")

        summary = (
            ("output_rows", rows),
            ("processed_tables", processed),
            ("duration", finish_at - start_at),
            ("lag", lag),
        )
        for sensor, value in summary:
            # a fresh labels dict per call, in case the reporter keeps a reference
            self.solomon.set_value(sensor, value, labels={"task": self.key, "ls": self.ls})
        # sleep 1 sec, to commit throttled solomon client queue
        time.sleep(1)


class StreamBaseTask(ConfMixin, ProcessedMixin, SolomonMixin, IndependentTask):

    """Take part of stream tables and parse it"""

    # JSON list of tables to process; when non-empty the processed-state lookup
    # is skipped and only these tables are used as input
    source_tables = Parameter(parse=json.loads, default="[]")
    # in the testing environment, process only every 10th unprocessed table
    skip_stream_on_testing = Parameter(
        parse=lambda value: value.lower() in {"+", "yes", "true", "t", "1"}, default="True"
    )
    # flag passed through to the YQL template context (see StreamBaseYqlTask)
    enable_cloud = Parameter(parse=lambda value: value.lower() in {"+", "yes", "true", "t", "1"}, default="True")

    @cached_property
    def processed(self):
        """Return all processed tables for current key

        Paths from the processed dyn table that have a ``process_time`` entry for
        this task's key and log source. At most 10000 records are read: those
        whose path carries the latest date. The set is empty when
        ``source_tables`` is given, so every one of those tables is processed.
        """
        if self.source_tables:
            # force read all table from params
            return set()

        query = """
            path FROM [{conf.paths.stream.processed}]
            WHERE log_source = {ls!r}
                AND (
                    process_time != NULL
                    AND try_get_int64(process_time, {key!r}) != NULL
                )
            ORDER BY regex_extract("(\\\\d{{4}}-\\\\d{{2}}-\\\\d{{2}})", path, "\\\\1") DESC
            LIMIT 10000
        """.format(
            conf=conf, key="/" + self.key, ls=self.log_source.Name
        )

        return set(item["path"] for item in self.yt.select_rows(query, format="json"))

    def observed_paths(self):
        """Should return iterable of observed YT folders to lookup input logs"""
        raise NotImplementedError()

    def input(self):
        """Lookup all observed path and list tables"""
        if self.source_tables:
            # read only tables from params
            for table in self.source_tables:
                yield table
            return

        for path in self.observed_paths():
            for table in self.yt.list(path, absolute=True):
                yield table

    def output(self):
        """Return path to table result"""
        return [yt.ypath.ypath_join(conf.paths.stream.storage, self.uniqid)]

    @contextlib.contextmanager
    def run_context(self):
        """Wrap the parent run_context with timing and a Solomon report.

        ``self.output_table`` is set to the first output path before yielding.
        The report is sent only when the wrapped run finishes without raising.
        """
        start_at = time.time()
        with super(StreamBaseTask, self).run_context() as ctx:
            self.output_table = str(list(self.output())[0])
            yield ctx
        finish_at = time.time()
        self.solomonify(start_at, finish_at)

    @cached_property
    def unprocessed_tables_final(self):
        """
        Return the tables the run should actually process.

        In the ``testing`` environment with ``skip_stream_on_testing`` enabled,
        only every 10th table of unprocessed_tables is returned. Otherwise
        unprocessed_tables is returned unchanged.
        """
        if (self.crypta_env == "testing") and self.skip_stream_on_testing:
            logger.info("Skip stream on testing")
            return self.unprocessed_tables[::10]
        return self.unprocessed_tables


class StreamBaseYqlTask(YQLRunnerTask, StreamBaseTask):

    """YQL Driven RUN() Take part of stream tables and parse it

    A stream task whose run is rendered from a YQL template; the stream-specific
    values are added to the template context.
    """

    def get_context_data(self, **kwargs):
        """Return the jinja rendering context extended with stream run values."""
        base_context = super(StreamBaseYqlTask, self).get_context_data(**kwargs)
        stream_values = {
            "stream": True,
            "date": self.date,
            "uniqid": self.uniqid,
            "input": self.unprocessed_tables_final,
            "output": self.output_table,
            "output_extra_data_dir": self.output_extra_data_dir,
            "output_idstorage_dir": self.output_idstorage_dir,
            "enable_cloud": self.enable_cloud,
        }
        base_context.update(stream_values)
        return base_context
