from __future__ import print_function

import argparse
import codecs
import datetime
import logging
import os
import time
import uuid

from cached_property import cached_property
from library.python import resource
from textwrap import dedent
from yql.api.v1.client import YqlClient
from yql.client.explain import YqlSqlValidateRequest
import yql.library.embedded.python.run as embedded
from yt import yson

from crypta.lib.python.jinja_resource import JinjaResourceMixin
import crypta.lib.python.yql.client as yql_client
from crypta.lib.python.yql_runner.util_random import reset_random_state


def sandbox_mark(title):
    """Append the ``YQL_MARK`` environment value to ``title``.

    When ``YQL_MARK`` is unset or empty the title is returned unchanged
    (including ``None``). Otherwise the result is ``"<title> <mark>"``.
    """
    mark = os.getenv("YQL_MARK")
    if not mark:
        return title
    # A marked title must contain the "YQL" substring, so a missing or empty
    # title is replaced with 'Crypta YQL' before the mark is appended.
    base = title if title else "Crypta YQL"
    return "{title} {mark}".format(title=base, mark=mark)


class EmbeddedYqlMixin(object):

    """ Mixin to run query in embedded yql """

    # The host class is expected to provide: yt_proxy, mrjob_binary,
    # udf_resolver_binary, udfs_dir, loglevel, syntax, get_libs(), get_token()
    # and get_title().

    @cached_property
    def custom_tokens(self):
        """Value passed to the operation factory as ``custom_tokens``; None here."""
        return None

    @cached_property
    def factory(self):
        """Create the embedded ``OperationFactory`` from the instance settings."""
        factory_kwargs = dict(
            yt_clusters=[{"name": "ytcluster", "cluster": self.yt_proxy}],
            user_data=list(self.get_libs()),
            yt_token=self.get_token(),
            mrjob_binary=self.mrjob_binary,
            udf_resolver_binary=self.udf_resolver_binary,
            udfs_dir=self.udfs_dir,
            log_level=self.loglevel,
            custom_tokens=self.custom_tokens,
        )
        return embedded.OperationFactory(**factory_kwargs)

    def __factory_run(self, query, mode):
        """Run ``query`` through the factory in the given ``mode``.

        The random state is reset before every launch. The resulting operation
        is remembered in ``self.last_op`` and returned.
        """
        reset_random_state()
        launch = self.factory.run
        operation = launch(
            query,
            title=sandbox_mark(self.get_title()),
            syntax_version=self.syntax,
            attributes=yson.dumps({"script_name": self.get_title()}),
            mode=mode,
        )
        self.last_op = operation
        return operation

    def _run_embedded(self, query):
        """Execute ``query`` in "run" mode and return the operation's yson result."""
        operation = self.__factory_run(query, "run")
        return operation.yson_result()

    def _validate_embedded(self, query):
        """Launch ``query`` in "validate" mode and print the operation plan.

        Always returns None.
        """
        # todo: make correct validate
        self.__factory_run(query, "validate")
        operation = self.last_op
        # debugging output: available operation attributes, then the plan
        print(dir(operation))
        print(operation.plan())
        return None


class ClientYqlMixin(object):

    """ Mixin to run query over yql api """

    # The host class is expected to provide: yql_server, yql_port, yt_proxy,
    # loglevel, syntax, get_libs(), get_title() and get_yql_token().

    @cached_property
    def client(self):
        """Create the ``YqlClient``; test-environment values override server/port."""
        overrides = yql_client.get_test_environment()
        server = overrides.get("yql_server", self.yql_server)
        port = overrides.get("yql_port", self.yql_port)
        return YqlClient(
            server=server,
            port=port,
            db=overrides.get("db"),
            db_proxy=self.yt_proxy,
            token=self.get_yql_token(),
        )

    def __get_row_request(self, query, action="run"):
        """Build the bare request object for ``action`` ("run" or "validate").

        Raises ``Exception`` for any other action.
        """
        if action not in ("run", "validate"):
            raise Exception('Action should be "run" or "validate"')

        # The client is only touched for "run"; validation uses the request class directly.
        make_request = self.client.query if action == "run" else YqlSqlValidateRequest
        return make_request(query=query, syntax_version=self.syntax, title=sandbox_mark(self.get_title()))

    def __prepare_request(self, query, action="run"):
        """Build a request with libraries attached, script name and log level set."""
        request = self.__get_row_request(query, action)

        for lib in self.get_libs():
            payload = lib["content"]
            disposition = lib["disposition"]
            if disposition == "resource":
                # "content" is a resource key: replace it with the resource data
                payload = resource.find(payload)
            elif disposition == "filesystem":
                # "content" is a file path: replace it with the file text (utf-8)
                with codecs.open(payload, "r", encoding="utf-8") as lib_file:
                    payload = lib_file.read()
            # any other disposition: "content" is attached as is
            request.attached_files.append({"name": lib["name"], "data": payload, "type": "CONTENT"})

        title = self.get_title()
        if title:
            request.additional_attributes["script_name"] = title

        # set the request logger to the configured level (CRYPTR-1735);
        # self.loglevel must be the name of a `logging` level constant
        level = getattr(logging, self.loglevel)
        request.logger.setLevel(level)
        return request

    def _run_client(self, query, action="run"):
        """Run the request, print its status and return the results.

        Raises ``Exception`` with one " - <error>" line per reported error
        when the results are not both ok and successful.
        """
        request = self.__prepare_request(query, action=action)
        request.run()
        results = request.get_results()
        print(results.status)

        if results.is_ok and results.is_success:
            return results

        raise Exception("\n".join([" - %s" % failure for failure in results.errors]))

    def _validate_client(self, query):
        """Validate ``query`` over the API (same flow as a run, action "validate")."""
        return self._run_client(query, action="validate")


class BaseParser(JinjaResourceMixin, EmbeddedYqlMixin, ClientYqlMixin):

    """ Run YQL query to parse logs """

    # NOTE: the class docstring above is also runtime data: get_title() falls
    # back to self.__doc__ when no script_name is given.

    # Default YT locations; "{crypta_env}" is filled in by _get_default_path()
    GRAPH_OUTPUT_DIR = "//home/crypta/{crypta_env}/state/graph"
    GRAPH_STREAM_DIR = "//home/crypta/{crypta_env}/state/graph/stream"
    INDEVICE_OUTPUT_DIR = "//home/crypta/{crypta_env}/state/graph/indevice"
    IDS_STORAGE_DIR = "//home/crypta/{crypta_env}/ids_storage"
    CRYPTA_PROFILES_EXPORT_DIR = "//home/crypta/{crypta_env}/profiles/export"
    TMP_DIR = "//home/crypta/{crypta_env}/tmp/graph/{folder}"

    # Defaults for the output dir template, the query template and the rendered-query cache
    OUTPUT_DIR = None
    QUERY_TEMPLATE = None
    _query = None

    def __init__(
        self,
        date,
        yt_proxy,
        pool,
        mrjob_binary=None,
        udf_resolver_binary=None,
        udfs_dir=None,
        loglevel="ERROR",
        limit=None,
        is_embedded=True,
        crypta_env=None,
        script_name=None,
        yql_server="yql.yandex.net",
        yql_port=443,
        is_local=False,
    ):
        """
        Store run settings.

        date, pool, limit -- exposed to the query template context
        yt_proxy -- YT cluster, used by both embedded and client modes
        mrjob_binary, udf_resolver_binary, udfs_dir -- passed to the embedded factory
        loglevel -- name of a `logging` level
        is_embedded -- run in embedded mode; ignored when is_local is set
        crypta_env -- environment name; when empty it is derived from ENV_TYPE
        script_name -- preferred text for the operation title
        yql_server, yql_port -- YQL API endpoint for client mode
        is_local -- disables embedded mode and marks the run as local
        """
        # template inputs
        self.date = date
        self.pool = pool
        self.limit = limit
        self.script_name = script_name

        # execution settings
        self.yt_proxy = yt_proxy
        self.mrjob_binary = mrjob_binary
        self.udf_resolver_binary = udf_resolver_binary
        self.udfs_dir = udfs_dir
        self.loglevel = loglevel
        self.yql_server = yql_server
        self.yql_port = yql_port

        # mode flags
        self.is_embedded = False if is_local else is_embedded
        self.is_yql_test = bool(yql_client.get_test_environment())
        self.is_local = True if self.is_yql_test else is_local

        # ENV_TYPE is lowercased and "stable" is replaced by "production"
        self.crypta_env = crypta_env or os.environ.get("ENV_TYPE", "stable").lower().replace("stable", "production")

    @property
    def syntax(self):
        """ Syntax version passed to YQL """
        return 1

    def render_query(self, **kwargs):
        """
        Build the query text: rendered header.sql.j2, a newline, then the
        rendered QUERY_TEMPLATE (empty when QUERY_TEMPLATE is None).
        The text is rendered once and cached, later calls ignore kwargs.
        """
        if self._query is not None:
            # reuse the cached text, so that the VERBOSE query random dirs
            # are equal to the RUN query dirs
            return self._query

        header_sql = super(BaseParser, self).render("header.sql.j2", **kwargs)
        if self.QUERY_TEMPLATE is None:
            body_sql = ""
        else:
            body_sql = super(BaseParser, self).render(self.QUERY_TEMPLATE, **kwargs)

        self._query = "\n".join([header_sql, body_sql])
        return self._query

    def get_context_data(self, **kwargs):
        """ Extend the parent template context with run settings and directories """
        context = super(BaseParser, self).get_context_data(**kwargs)
        own_values = dict(
            date=self.date,
            pool=self.pool,
            yt_proxy=self.yt_proxy,
            limit=self.limit,
            is_embedded=self.is_embedded,
            is_local=self.is_local,
            salt=str(time.time()),
            crypta_env=self.crypta_env,
            query_title=sandbox_mark(self.get_title()),
            udf_source=self.get_udf_source(),
            # the parent's flag is consulted only outside of a yql test environment
            is_test_run=self.is_yql_test or context["is_test_run"],
        )
        context.update(**own_values)
        context.update(self.get_dirs())
        return context

    def get_libs(self):
        """ List the SQL libraries (bundled resources) attached to every query """
        resource_libs = (
            ("aggregation_lib.sql", "/lib/aggregation_lib.sql"),
            ("metrica_lib.sql", "/lib/metrica_lib.sql"),
            ("sn_id_parse_lib.sql", "/lib/sn_id_parse_lib.sql"),
            ("config.sql", "/lib/config.sql"),
            ("ut_utils_lib.sql", "/lib/ut_utils_lib.sql"),
        )
        return [
            {"name": lib_name, "content": lib_key, "disposition": "resource", "type": "library"}
            for lib_name, lib_key in resource_libs
        ]

    def get_title(self):
        """
        Return the yql operation title, or None when there is no text for it.
        The text is the first non-empty of: script_name, class docstring, class name.
        """
        label = self.script_name or self.__doc__ or self.__class__.__name__ or ""
        label = label.strip()
        if label:
            return "Crypta YQL [{env}] {text}".format(env=self.crypta_env, text=label)
        return None

    def get_udf_source(self):
        """ UDF source from YQL_UDF_SOURCE, "yt" when unset or empty """
        configured = os.environ.get("YQL_UDF_SOURCE")
        return configured if configured else "yt"

    def get_token(self):
        """ YT token from the YT_TOKEN environment variable (None when unset) """
        return os.getenv("YT_TOKEN")

    def get_yql_token(self):
        """ Token for client mode: YQL_TOKEN, falling back to get_token() """
        yql_token = os.environ.get("YQL_TOKEN")
        if yql_token:
            return yql_token
        return self.get_token()

    def get_tmp_dir(self):
        """
        Return a tmp dir path with a fresh uuid4 folder on every call.
        This prevents equal table names from TRandGuid::GenGuid()
        when embedded yql is used in luigi threads.
        """
        unique_folder = uuid.uuid4()
        return self.TMP_DIR.format(crypta_env=self.crypta_env, folder=unique_folder)

    def get_dirs(self):
        """ Resolve each directory from its environment variable or class default """
        # (context key, environment variable, default template)
        dir_specs = (
            ("output_dir", "OUTPUT_DIR", self.OUTPUT_DIR),
            ("graph_output_dir", "GRAPH_YT_OUTPUT_FOLDER", self.GRAPH_OUTPUT_DIR),
            ("graph_stream_dir", "GRAPH_STREAM_FOLDER", self.GRAPH_STREAM_DIR),
            ("indevice_output_dir", "INDEVICE_YT_FOLDER", self.INDEVICE_OUTPUT_DIR),
            ("ids_storage_dir", "CRYPTA_IDS_STORAGE", self.IDS_STORAGE_DIR),
            ("crypta_profiles_export_dir", "CRYPTA_PROFILES_EXPORT_DIR", self.CRYPTA_PROFILES_EXPORT_DIR),
        )
        # tmp_dir (see get_tmp_dir) is deliberately not exported: the default tmp dir
        # without uuid is used, because the random state is reset before each call
        return {key: self._get_dir_or_default(env_name, template) for key, env_name, template in dir_specs}

    def run(self, **kwargs):
        """ Render the query and execute it in embedded or client mode """
        query = self.render_query(**kwargs)
        execute = self._run_embedded if self.is_embedded else self._run_client
        return execute(query)

    def validate(self, **kwargs):
        """ Render the query and validate it in embedded or client mode """
        query = self.render_query(**kwargs)
        check = self._validate_embedded if self.is_embedded else self._validate_client
        return check(query)

    def _get_default_path(self, template):
        """ Substitute the current crypta_env into a path template """
        return template.format(crypta_env=self.crypta_env)

    def _get_dir_or_default(self, env, default):
        """
        Return the env variable value (trailing slashes stripped) when it is
        non-empty, otherwise the formatted default template (None without one).
        """
        from_environ = os.environ.get(env, "").rstrip("/")
        # the default is formatted even when the environment value wins
        fallback = None
        if default is not None:
            fallback = self._get_default_path(default)
        return from_environ or fallback


def timer(fun):
    """Wrap ``fun`` to print its start time, end time and total runtime.

    The wrapper is applied only when the ``VERBOSE`` environment variable is
    set to a non-empty value at decoration time; otherwise ``fun`` itself is
    returned untouched.

    The wrapper forwards all positional and keyword arguments to ``fun`` and
    returns its result. Nothing is printed when ``fun`` raises.
    """
    # Local import keeps this block self-contained (the module does not import functools).
    import functools

    if not os.environ.get("VERBOSE"):
        return fun

    def wrapper(*args, **kwargs):
        start = time.time()
        result = fun(*args, **kwargs)
        end = time.time()
        # Whitespace-only lines are blanked by dedent, so the report starts
        # with three empty lines followed by the three labelled rows.
        report = dedent(
            """\
            \n\n
            START:\t{start:%Y-%m-%d %H:%M:%S}
            END:\t{end:%Y-%m-%d %H:%M:%S}
            TOTAL RUNTIME:\t{timedelta}
            """
        )
        print(
            report.format(
                start=datetime.datetime.fromtimestamp(start),
                end=datetime.datetime.fromtimestamp(end),
                timedelta=datetime.timedelta(seconds=end - start),
            )
        )
        return result

    try:
        # Keep the wrapped callable's name/docstring for logs and introspection.
        wrapper = functools.wraps(fun)(wrapper)
    except AttributeError:
        # Python 2 raises here for callables without __name__ (e.g. functools.partial);
        # the plain wrapper still works, only the metadata is missing.
        pass
    return wrapper


def make_common_arg_parser(description):
    """Create the argument parser with the options shared by all runners.

    Options: --cluster, --pool, --loglevel, --mrjob, --udfr, --udfsdir, --limit
    and --client (which switches ``is_embedded`` from True to False).
    """
    # (flag, add_argument keyword arguments), registered in this order
    option_specs = (
        ("--cluster", dict(help="YT cluster to use (default: %(default)s)", default="hahn.yt.yandex.net")),
        ("--pool", dict(help="YT pool to use (default: %(default)s)", default="crypta_graph")),
        ("--loglevel", dict(help="Logging level (default: %(default)s)", default="ERROR")),
        ("--mrjob", dict(help="Path to mrjob binary (default: %(default)s)", default="./yql/tools/mrjob/mrjob")),
        (
            "--udfr",
            dict(
                help="Path to UDF resolver binary (default: %(default)s)",
                default="./yql/tools/udf_resolver/udf_resolver",
            ),
        ),
        (
            "--udfsdir",
            dict(help="Path to UDF search directory (default: %(default)s)", default="./yql/udfs/common"),
        ),
        ("--limit", dict(help="SQL Limit (use for speedup debugging)")),
        (
            "--client",
            dict(help="Run query via YQL Client", action="store_false", dest="is_embedded", default=True),
        ),
    )

    common_parser = argparse.ArgumentParser(description=description)
    for flag, options in option_specs:
        common_parser.add_argument(flag, **options)
    return common_parser


def make_arg_parser(description, date_range=False):
    """Create the common parser extended with optional positional date arguments.

    With ``date_range`` set, ``date_start`` and ``date_end`` are added (no defaults).
    Otherwise a single ``date`` is added, defaulting to yesterday's date
    (computed when this function is called) as YYYY-MM-DD.
    """
    # TODO: change to protobuf config, and rm this argparser
    parser = make_common_arg_parser(description)

    if not date_range:
        yesterday = datetime.date.today() - datetime.timedelta(days=1)
        parser.add_argument(
            "date",
            nargs="?",
            help="Calc date. Format YYYY-MM-DD (default: yesterday %(default)s)",
            default=yesterday.strftime("%Y-%m-%d"),
        )
        return parser

    range_arguments = (
        ("date_start", "Start date. Format YYYY-MM-DD"),
        ("date_end", "End date. Format YYYY-MM-DD"),
    )
    for dest, help_text in range_arguments:
        parser.add_argument(dest, nargs="?", help=help_text)
    return parser
