#!/usr/bin/env python

# Get OAuth token here:
#     https://yql.yandex-team.ru/?settings_mode=token


import logging


# Sample quarantine stream table; used as the default ``yt_table`` of
# QuarantineYQLQuery.run_query() for manual testing.
TEST_YT_TABLE = "//home/security/logfeller/streams/xiva-xivaserver-access-log/2020-06-13T03:10:00Z/stream.quarantine"


class QuarantineYQLQuery(object):
    """
    Run a YQL query that extracts validated secrets from a logfeller
    quarantine table, and fetch its results.

    Usage::

        yt_account = "security"
        yql_cache_folder = "//home/{}/logfeller/yql_cache".format(yt_account)
        qq = QuarantineYQLQuery(yql_token, yt_token, yql_cache_folder,
                                udf_lib_name, udf_lib_url, query_version=2)
        qq.run_query(yt_table=TEST_YT_TABLE, yt_cluster="hahn")
        share_url = qq.get_share_url()
        result_rows = qq.get_result_rows()

    :param yql_token: OAuth token passed to ``YqlClient``.
    :param yt_token: YT token; only used by :meth:`clear_cache`.
    :param yql_cache_folder: YT folder set as ``yt.TmpFolder`` for the query
        and cleaned up by :meth:`clear_cache`.
    :param udf_lib_name: file name under which the UDF library is attached
        (``PRAGMA File``) and loaded (``PRAGMA udf``). The queries call
        ``AntSecret::`` functions, presumably provided by this library.
    :param udf_lib_url: URL the UDF library file is fetched from.
    :param query_version: ``1`` re-validates the secret positions stored in
        the table's ``secrets`` column. Any other value runs
        ``AntSecret::Search`` over ``event_data``.
    """

    def __init__(self, yql_token, yt_token, yql_cache_folder, udf_lib_name, udf_lib_url, query_version):
        from yql.api.v1.client import YqlClient

        # The client is not bound to a cluster: the cluster is chosen per query
        # through the `cluster.`table`` prefix built in _prepare_query().
        self._client = YqlClient(token=yql_token)
        self._yt_token = yt_token
        self._yql_cache_folder = yql_cache_folder
        self._request = None  # set by run_query()
        self._udf_lib_name = udf_lib_name
        self._udf_lib_url = udf_lib_url
        self._query_version = query_version

    def _require_request(self):
        """Return the request started by run_query(); raise if there is none."""
        if self._request is None:
            raise RuntimeError("QuarantineYQLQuery: run_query() must be called first")
        return self._request

    def _prepare_query(self, yt_table, yt_cluster):
        """
        Build the YQL text for ``yt_table`` on ``yt_cluster``.

        Not every column is selected (e.g. original_record, masked_record are
        skipped). Only the raw ``event_data`` and the secret location info are
        needed. At most 100 distinct secrets are returned, one row per secret:
        column 0 is the secret info struct and column 1 is the ``event_data``
        it was found in.
        """
        if self._query_version == 1:
            # v1: the table already carries a `secrets` column with positions;
            # each one is cut out of event_data and re-checked with Validate.
            query = """
            PRAGMA File("__PLACEHOLDER_UDF_LIB_NAME__", "__PLACEHOLDER_UDF_LIB_URL__");
            PRAGMA udf("__PLACEHOLDER_UDF_LIB_NAME__");
            PRAGMA yt.TmpFolder = "__PLACEHOLDER_YQL_CACHE__";

            $secret_type = ($type) -> {
                -- must be in sync with https://a.yandex-team.ru/arc/trunk/arcadia/security/ant-secret/snooper/internal/secret/secret_types.h#ESecretType
                return CASE $type
                    WHEN 1 THEN "Yandex OAuth"
                    WHEN 2 THEN "Yandex Session (cookie)"
                    WHEN 4 THEN "TVM Ticket"
                    ELSE "Unknown (please report this to security@yandex-team.ru)"
                END
            };

            $parse_secret = ($row, $info) -> {
                return ListMap(Yson::ConvertToList($info), ($x) -> {
                    $type = CAST(Yson::LookupUint64($x, "type") AS Uint32);
                    $pos = CAST(Yson::LookupUint64($x, "secret_pos") AS Uint32);
                    $len = CAST(Yson::LookupUint64($x, "secret_len") AS Uint32);
                    $secret = SUBSTRING($row, $pos, $len);
                    return AsStruct(
                        $secret as secret,
                        $secret_type($type) as secret_type,
                        $pos AS pos,
                        $len AS len,
                        AntSecret::Validate($type, $secret) as valid,
                    );
                });
            };

            $valid_count = ($info) -> {
                return ListLength(ListFilter($info, ($x) -> {
                    return $x.valid == true;
                }));
            };

            SELECT SOME(secret_info), SOME(event_data) FROM (
                SELECT secret_info.secret as secret, secret_info, event_data FROM (
                    SELECT
                        $parse_secret(event_data, secrets) as secrets,
                        event_data
                    FROM __PLACEHOLDER_CLUSTER__.`__PLACEHOLDER_TABLE__`
                    WHERE $valid_count($parse_secret(event_data, secrets)) > 0
                    LIMIT 150
                )
                FLATTEN LIST BY secrets as secret_info
                WHERE secret_info.valid == true
            )
            GROUP BY secret
            LIMIT 100
            ;
            """
        else:
            # v2+: search event_data for secrets with the UDF directly.
            query = """
            PRAGMA File("__PLACEHOLDER_UDF_LIB_NAME__", "__PLACEHOLDER_UDF_LIB_URL__");
            PRAGMA udf("__PLACEHOLDER_UDF_LIB_NAME__");
            PRAGMA yt.TmpFolder = "__PLACEHOLDER_YQL_CACHE__";

            $valid_count = ($info) -> {
                return ListLength(ListFilter($info, ($x) -> {
                    return $x.validated == true;
                }));
            };

            SELECT SOME(secret_info), SOME(event_data) FROM (
                SELECT secret_info.secret as secret, secret_info, event_data FROM (
                    SELECT info, event_data FROM (
                        SELECT info, event_data, $valid_count(info) as valid_count FROM (
                            SELECT AntSecret::Search(event_data) as info, event_data
                            FROM __PLACEHOLDER_CLUSTER__.`__PLACEHOLDER_TABLE__`
                        )
                    )
                    WHERE valid_count > 0
                    LIMIT 10
                )
                FLATTEN LIST BY info AS secret_info
                WHERE secret_info.validated == true
            )
            GROUP BY secret
            LIMIT 100;
            """

        # Order matters only if a value contains another placeholder; keep the
        # historical order. Values are inserted verbatim (no YQL escaping), so
        # they must come from trusted configuration.
        replacements = (
            ("__PLACEHOLDER_YQL_CACHE__", self._yql_cache_folder),
            ("__PLACEHOLDER_CLUSTER__", yt_cluster),
            ("__PLACEHOLDER_TABLE__", yt_table),
            ("__PLACEHOLDER_UDF_LIB_URL__", self._udf_lib_url),
            ("__PLACEHOLDER_UDF_LIB_NAME__", self._udf_lib_name),
        )
        for placeholder, value in replacements:
            query = query.replace(placeholder, value)
        return query

    def run_query(self, yt_table=TEST_YT_TABLE, yt_cluster="hahn"):
        """
        Build the query for ``yt_table`` on ``yt_cluster`` and start it.

        The request handle is stored for get_share_url() / get_result_rows().
        """
        query = self._prepare_query(yt_table, yt_cluster)
        self._request = self._client.query(query, syntax_version=1)
        self._request.run()

    def get_share_url(self):
        """
        Return the share URL of the request started by run_query().

        :raises RuntimeError: if run_query() has not been called.
        """
        return self._require_request().share_url

    def get_result_rows(self):
        """
        Wait for the query results and return them as a list of dicts.

        Each dict is ``{"raw_data": <event_data>, "secrets_info": <JSON of the
        secret info struct>}``.

        :raises RuntimeError: if run_query() has not been called.
        """
        request = self._require_request()
        from yt.yson import yson_to_json

        results = request.get_results()  # BLOCKING
        table = results.table
        # table.fetch_full_data() would fetch more than the first page of rows
        # (100 by default); both queries are LIMIT 100 anyway.
        return [
            {"raw_data": row[1], "secrets_info": yson_to_json(row[0])}
            for row in table.rows
        ]

    def clear_cache(self, remote_temp_files_directory, yt_cluster="hahn",
                    max_age_days=365, max_removed_tables=10):
        """
        Clean the YQL cache folder on ``yt_cluster``. This is best-effort.

        1. Every child of the cache folder except ``tmp`` is removed
           recursively.
        2. ``tmp`` holds the cached tables behind shared result links. Up to
           ``max_removed_tables`` of its entries are removed, namely those whose
           modification date is older than ``max_age_days`` days.

        ``yt.YtHttpResponseError`` failures are logged and skipped. One bad node
        therefore does not abort the rest of the cleanup.

        :param remote_temp_files_directory: value for the yt.wrapper
            ``remote_temp_files_directory`` config option.
        :param yt_cluster: YT proxy to connect to.
        :param max_age_days: age threshold for removing ``tmp`` entries.
        :param max_removed_tables: maximum number of ``tmp`` entries removed
            per call.
        """
        import datetime
        import yt.wrapper as yt

        yt.config["token"] = self._yt_token
        yt.config["remote_temp_files_directory"] = remote_temp_files_directory
        yt.config.set_proxy(yt_cluster)

        cache_root = self._yql_cache_folder.rstrip("/")

        # Phase 1: clear all non-tmp folders.
        folders = []
        try:
            folders = [name for name in yt.list(self._yql_cache_folder) if name != "tmp"]
            logging.info("[+] QuarantineYQLQuery::clear_cache. listing _yql_cache_folder(%s): %s",
                         self._yql_cache_folder, folders)
        except yt.YtHttpResponseError as err:
            logging.warning("[+] QuarantineYQLQuery::clear_cache. Error on listing _yql_cache_folder(%s): %s",
                            self._yql_cache_folder, err)

        for folder in folders:
            cache_folder = cache_root + "/" + folder
            try:
                logging.info("[+] QuarantineYQLQuery::clear_cache. deleting cache_folder: %s", cache_folder)
                yt.remove(cache_folder, recursive=True)
            except yt.YtHttpResponseError as err:
                logging.warning("[+] QuarantineYQLQuery::clear_cache. Error on deleting %s: %s", cache_folder, err)

        # Phase 2: clear old cached shared tables from tmp.
        tmp_folder = cache_root + "/tmp"
        try:
            tablenames = yt.list(tmp_folder)
        except yt.YtHttpResponseError as err:
            logging.warning("[+] QuarantineYQLQuery::clear_cache. Error on listing %s: %s", tmp_folder, err)
            return

        # Only the date part of modification_time is compared (day precision).
        old_dt = datetime.datetime.utcnow() - datetime.timedelta(days=max_age_days)
        removed = 0
        for tablename in tablenames:
            if removed >= max_removed_tables:
                break
            tablepath = tmp_folder + "/" + tablename
            try:
                modification_time = yt.get(tablepath + "/@modification_time")
                table_dt = datetime.datetime.strptime(modification_time.split("T")[0], "%Y-%m-%d")
                if table_dt < old_dt:
                    yt.remove(tablepath)
                    removed += 1
            except yt.YtHttpResponseError as err:
                logging.warning("[+] QuarantineYQLQuery::clear_cache. Error on cleaning %s: %s", tablepath, err)
