# coding: utf-8
import datetime
import json
import logging
import os
import re
import shutil
import tempfile
from __builtin__ import staticmethod
from functools import partial
from multiprocessing import Pool

import sandbox.common.types.client as ctc
import sandbox.common.types.task as ctt
import sandbox.sdk2 as sdk2
from bson import json_util
from sandbox.projects.common import binary_task
from sandbox.projects.common.arcadia import sdk as arcadia_sdk
from sandbox.projects.common.environments import MongodbEnvironment
from sandbox.sdk2 import parameters
from sandbox.sdk2.helpers import subprocess

import bson_schema_converter
import util

# Options passed to every mongoexport call; a spec's "mongoexport_options" are
# merged on top of these in run_mongoexport (presumably overriding on conflict).
MONGOEXPORT_DEFAULT_OPTIONS = dict([
    ('--readPreference', 'secondary'),
    ('--type', 'json'),
    ('--jsonFormat', 'canonical'),
])

# Yav secret holding the robot's credentials.
ROBOT_TOKEN_SECRET = 'sec-01e0r2m2znx2nqqvv2ygnfdxy4'
# Key inside ROBOT_TOKEN_SECRET with the OAuth token used to mount Arcadia via arc.
ARCADIA_TOKEN_KEY = 'ROBOT_MEDIABILLING_OAUTH'
# Not referenced in this file; presumably consumed by util.yt_token() -- verify.
YT_TOKEN_KEY = "ROBOT_MEDIABILLING_YT_TOKEN"


class MediabillingMongoExportToYt(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """ mongoexport | parser script | yt write

    Exports a Mongo collection described by a YAML spec stored in Arcadia into a YT table:

    1. ``mongoexport`` dumps the collection into a local JSON-lines file. In "snapshot"
       mode the whole collection is dumped; in "delta" mode only documents whose
       timestamp field is newer than the previously exported one are dumped.
    2. Every line is converted by ``bson_schema_converter`` according to the spec's
       schema file.
    3. Inside one YT transaction, the converted rows are uploaded, merged (sorted or
       concatenated) into the destination table, and the export timestamps are
       recorded under ``//home/mediabilling/exports-history``.

    Unless the task runs from trunk as a stable release (see ``is_stable_env``), the
    "testing" Mongo connection is used and data goes into a temporary YT table.
    """

    class Requirements(sdk2.Task.Requirements):
        environments = (
            MongodbEnvironment(),
        )
        client_tags = ctc.Tag.LINUX_XENIAL

    class Parameters(sdk2.Task.Parameters):
        binary_release = binary_task.binary_release_parameters(stable=True)
        kill_timeout = 20 * 60
        description = "Dump mongo to yt"
        spec_path = parameters.String(
            'Relative path to spec file',
            required=True,
            description="Used as 'arcadia/{spec_path}'"
        )
        arcadia_path = parameters.String(
            'Path to arcadia (either trunk or users branch)',
            required=True,
            default="arcadia-arc:/#trunk",
            description="For your branch use 'arcadia-arc:/#users/login/my-branch'"
        )

    def on_execute(self):
        """Mount Arcadia, load the export spec and run the export pipeline."""
        super(MediabillingMongoExportToYt, self).on_execute()
        self.init_yt()
        oauth_token = sdk2.yav.Secret(ROBOT_TOKEN_SECRET).data()[ARCADIA_TOKEN_KEY]
        with arcadia_sdk.mount_arc_path(self.Parameters.arcadia_path, arc_oauth_token=oauth_token) as arcadia_path:
            spec_file = os.path.join(arcadia_path, self.Parameters.spec_path)
            spec_dir = os.path.abspath(os.path.join(spec_file, os.pardir))

            # Roots used by resolve_path() for paths referenced from the spec.
            self.arcadia_dir = arcadia_path
            self.spec_dir = spec_dir

            spec = util.load_yaml(spec_file)
            # Upper bound of this export (the query uses "$lte"). It is shifted one
            # minute back, presumably so documents still being written "now" are
            # picked up by the next delta run instead.
            current_datetime = datetime.datetime.now() - datetime.timedelta(minutes=1)
            # tempfile.TemporaryDirectory() is Python 3 only, so the directory is
            # removed by hand. It holds the full Mongo dump, the converted rows and
            # mongo.conf with the plaintext password, so it must not be left behind.
            temp_dir = tempfile.mkdtemp()
            try:
                self._export_spec(spec, current_datetime, temp_dir)
            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

    def _export_spec(self, spec, current_datetime, temp_dir):
        """Run mongoexport -> convert -> YT upload for ``spec``, using ``temp_dir`` for scratch files."""
        query = self.build_mongo_query(spec, current_datetime)
        mongoexport_path = os.path.join(temp_dir, "mongoexport.json")

        connections = util.load_yaml(self.resolve_path(spec["mongo"]["connection"]))
        if self.is_stable_env():
            connection = connections["production"]
        else:
            self.set_info("Running unstable version, so testing db connection will be used")
            connection = connections["testing"]
        self.run_mongoexport(spec["mongo"], query, mongoexport_path, connection, temp_dir)

        lines_count = util.lines_count(mongoexport_path)
        self.set_info("Exported {} lines from mongo".format(lines_count))
        if lines_count == 0:
            # Nothing exported: neither the YT table nor the export history is touched.
            return

        mapper_spec_path = self.resolve_path(spec["schema"])
        mapped_json_path = os.path.join(temp_dir, "mapped.json")
        mapper_spec = util.load_yaml(mapper_spec_path)
        self.convert(mongoexport_path, mapped_json_path, mapper_spec, lines_count)

        from yt.wrapper import Transaction
        with Transaction(client=self.yt_client) as tx:
            schema = self.load_schema_from_mapper(mapper_spec)
            result_path = self.create_table(current_datetime, schema, spec)

            export_path = self.yt_client.create_temp_table(prefix="export",
                                                           attributes={"schema": schema})
            self.yt_write(tx, mapped_json_path, export_path)

            is_delta = spec["mode"] == "delta"
            source_tables = [export_path]
            if is_delta:
                # Delta mode merges the new rows with what is already in the table.
                source_tables.append(result_path)

            if "sort" in spec["yt"]:
                self.yt_client.run_sort(source_table=source_tables,
                                        destination_table=result_path,
                                        sort_by=spec["yt"]["sort"])
            else:
                self.yt_client.concatenate(source_paths=source_tables,
                                           destination_path=result_path)

            last_doc_datetime = None
            if is_delta:
                last_doc_datetime = self.get_last_doc_timestamp(mongoexport_path,
                                                                spec["mongo"]["timestamp_field"])
            self.save_export_datetime(spec["name"], current_datetime, last_doc_datetime)
            self.set_info(u'Exported to <a href="https://yt.yandex-team.ru/hahn/navigation?path={0}">{0}</a>'
                          .format(result_path), do_escape=False)

    def create_table(self, current_datetime, schema, spec):
        """Return the YT path to export into, creating the table if necessary.

        In a stable environment the path is ``spec["yt"]["path"]`` with date
        placeholders substituted, and the table is created (with parent nodes) if
        missing. Otherwise a fresh temporary table is created and its path returned.
        """
        table_path = self.build_table_name(current_datetime, spec["yt"]["path"])

        if not self.is_stable_env():
            self.set_info("Running unstable version, so data will be exported into tmp table")
            return self.yt_client.create_temp_table(prefix=table_path.split("/")[-1],
                                                    attributes={"schema": schema})
        if not self.yt_client.exists(table_path):
            self.yt_client.create("table", table_path, attributes={"schema": schema}, recursive=True)
        return table_path

    def is_stable_env(self):
        """True only when running a STABLE binary release built from Arcadia trunk."""
        is_arcadia_trunk = self.Parameters.arcadia_path == "arcadia-arc:/#trunk"
        is_task_stable = self.Parameters.binary_executor_release_type == ctt.ReleaseStatus.STABLE
        return is_arcadia_trunk and is_task_stable

    def init_yt(self):
        """Create ``self.yt_client`` for the hahn cluster."""
        from yt.wrapper import YtClient
        self.yt_client = YtClient(proxy='hahn.yt.yandex.net', token=util.yt_token())

    @staticmethod
    def get_last_doc_timestamp(mongoexport_path, timestamp_field):
        """Return the greatest ``timestamp_field`` value among exported documents.

        Each non-empty line of ``mongoexport_path`` is parsed with
        ``bson.json_util.loads``. Returns ``None`` when the file holds no documents.
        """
        max_timestamp = None
        with open(mongoexport_path) as f:
            for line in f:
                if not line.strip():
                    continue
                doc_timestamp = json_util.loads(line)[timestamp_field]
                # Compare against None explicitly rather than relying on truthiness.
                if max_timestamp is None or doc_timestamp > max_timestamp:
                    max_timestamp = doc_timestamp
        return max_timestamp

    def convert(self, source_path, target_path, mapper_spec, lines_count, batch_size=10000):
        """Convert mongoexport JSON lines into rows for YT using a process pool.

        Lines of ``source_path`` are converted in batches of ``batch_size`` with
        ``bson_schema_converter.convert(line, spec=mapper_spec)``. Results are written
        to ``target_path`` in input order. ``lines_count`` is used only for progress
        logging.
        """
        mapper = partial(bson_schema_converter.convert, spec=mapper_spec)
        pool = Pool()
        try:
            with open(target_path, "w") as target:
                with open(source_path, "r") as source:
                    batch = []
                    lines_processed = 0
                    for line in source:
                        batch.append(line)
                        if len(batch) >= batch_size:
                            self.flush(batch, mapper, target, pool)
                            lines_processed += len(batch)
                            batch = []
                            logging.info("Processed {}/{} from {}".format(lines_processed, lines_count, source_path))
                    if batch:
                        self.flush(batch, mapper, target, pool)
                        lines_processed += len(batch)
                        logging.info("Processed {}/{} from {}".format(lines_processed, lines_count, source_path))
        finally:
            # Python 2's Pool is not a context manager: always stop the workers,
            # otherwise they outlive the conversion (or a failure in it).
            pool.terminate()
            pool.join()

    def flush(self, batch, mapper, target, pool):
        """Convert ``batch`` in parallel and append the results to ``target`` in order."""
        target.writelines(pool.map(mapper, batch))

    def build_mongo_query(self, spec, current_datetime):
        """Build the mongoexport filter for the spec's mode.

        "snapshot" exports everything (empty query). "delta" exports documents with
        ``timestamp_field <= current_datetime``. After the first export it also
        requires ``timestamp_field`` to be greater than the last exported document's
        timestamp.

        Raises:
            RuntimeError: on an unknown ``spec["mode"]``.
        """
        if spec["mode"] == "snapshot":
            return {}
        elif spec["mode"] == "delta":
            previous_datetime = self.get_last_exported_doc_datetime(spec["name"])
            timestamp_field = spec["mongo"]["timestamp_field"]
            query = {
                timestamp_field: {
                    "$lte": current_datetime
                }
            }
            if previous_datetime is not None:
                # Stored as milliseconds by save_export_datetime (via util.msk_millis).
                # NOTE(review): fromtimestamp() converts using the host's local time
                # zone -- confirm this matches what util.msk_millis assumes.
                query[timestamp_field]["$gt"] = datetime.datetime.fromtimestamp(previous_datetime / 1000.0)
            return query
        else:
            raise RuntimeError("Unknown mode: " + spec["mode"])

    def yt_write(self, tx, json_path, table_path):
        """Upload JSON rows from ``json_path`` into ``table_path`` within transaction ``tx``.

        Runs ``ya tool yt write`` from the mounted Arcadia with parallel unordered
        writes. Its output goes to the task's "yt_write" process log.
        """
        cmd = [
            os.path.join(self.arcadia_dir, "ya"),
            "tool",
            "yt",
            "--tx",
            tx.transaction_id,
            "write",
            "--table",
            table_path,
            "--format",
            "<encode_utf8=%false>json",
            "--config",
            "{write_parallel={enable=%true;unordered=%true;}}",
        ]
        with open(json_path, 'r') as json_file:
            with sdk2.helpers.ProcessLog(self, logger='yt_write') as pl:
                process = subprocess.Popen(cmd,
                                           stdin=json_file,
                                           stdout=pl.stdout,
                                           stderr=pl.stderr,
                                           env=MediabillingMongoExportToYt.copy_env_for_yt())
                util.wait_process(process)

    @staticmethod
    def build_table_name(current_datetime, yt_path):
        """Substitute ``%date%`` (YYYY-MM-DD) and ``%datetime%`` (YYYY-MM-DDTHH:MM:SS) in ``yt_path``."""
        return yt_path \
            .replace("%date%", current_datetime.strftime('%Y-%m-%d')) \
            .replace("%datetime%", current_datetime.strftime('%Y-%m-%dT%H:%M:%S'))

    @staticmethod
    def copy_env_for_yt():
        """Return a copy of the process environment configured for the ``yt`` CLI on hahn."""
        environment = os.environ.copy()
        environment['YT_TOKEN'] = util.yt_token()
        environment['YT_PROXY'] = "hahn.yt.yandex.net"
        environment['YT_LOG_LEVEL'] = 'DEBUG'
        return environment

    @staticmethod
    def load_schema_from_mapper(mapper_spec):
        """Build a strict YT table schema from the mapper spec.

        ``mapper_spec`` maps column name -> ``{"type": ..., "required": ...}``.
        "array" and "object" become ``any`` columns, which are never required.
        """
        import yt.yson as yson
        json_schema = []
        for name, field in mapper_spec.items():
            yson_type = MediabillingMongoExportToYt.to_yson_type(field["type"])
            json_schema.append({
                "name": name,
                "type": yson_type,
                "required": field.get("required", False) if yson_type != "any" else False
            })
        schema = yson.YsonList(json_schema)
        schema.attributes["strict"] = True
        return schema

    @staticmethod
    def to_yson_type(type_name):
        """Map a mapper-spec type name to a YT column type ("array"/"object" -> "any")."""
        return "any" if type_name in ["array", "object"] else type_name

    def run_mongoexport(self, mongo_spec, query, filename, connection, temp_dir):
        """Dump ``mongo_spec["collection"]`` into ``filename`` with mongoexport.

        Connection values may be ``${sec-...:key}`` Yav references (see
        ``resolve_value``). The password goes through a config file in ``temp_dir``
        rather than the command line. ``query`` is AND-ed with any ``--query`` from
        the spec's ``mongoexport_options``.
        """
        config_file = os.path.join(temp_dir, "mongo.conf")
        self.write_mongo_config(config_file, resolve_value(connection["password"]))
        cmd = [
            "mongoexport",
            "-vvvvvv",
            "--uri", resolve_value(connection["uri"]),
            "--username", resolve_value(connection["username"]),
            "--config", config_file,
            "--authenticationDatabase", resolve_value(connection["authenticationDatabase"]),
            "--collection", mongo_spec["collection"],
            "--out", filename,
        ]
        additional_options = util.merge_two_dicts(MONGOEXPORT_DEFAULT_OPTIONS,
                                                  mongo_spec.get("mongoexport_options", {}))

        if query or "--query" in additional_options:
            additional_query = json_util.loads(additional_options.get("--query", "{}"))
            combined_query = self.combine_mongo_queries(query, additional_query)
            additional_options["--query"] = json_util.dumps(combined_query,
                                                            json_options=json_util.RELAXED_JSON_OPTIONS)

        for option_name, value in additional_options.items():
            cmd.extend([option_name, str(value)])

        with sdk2.helpers.ProcessLog(self, logger='mongoexport') as pl:
            process = subprocess.Popen(cmd, stdout=pl.stdout, stderr=pl.stderr)
            util.wait_process(process)

    @staticmethod
    def combine_mongo_queries(query_a, query_b):
        """AND two Mongo filters together; an empty filter is simply dropped."""
        if query_a and query_b:
            return {"$and": [query_a, query_b]}
        return query_a if query_a else query_b

    def get_last_exported_doc_datetime(self, export_name):
        """Return the latest stored ``last_document_timestamp`` (milliseconds), or None.

        Rows where the timestamp is null (as written for snapshot exports) are ignored.
        """
        table_path = self.history_yt_path(export_name) + "/last"
        if not self.yt_client.exists(table_path):
            return None
        timestamps = [row["last_document_timestamp"]
                      for row in self.yt_client.read_table(table_path, format='json')
                      if row.get("last_document_timestamp") is not None]
        return max(timestamps) if timestamps else None

    def save_export_datetime(self, export_name, export_datetime, last_doc_timestamp):
        """Record this export in the history directory on YT.

        Overwrites ``<history>/last`` with a single row, appends the same row to
        ``<history>/history`` and re-sorts the history by ``last_export_timestamp``.
        """
        import yt.wrapper as yt
        row = {
            "name": export_name,
            "last_export_timestamp": util.msk_millis(export_datetime),
            "last_document_timestamp": util.msk_millis(last_doc_timestamp) if last_doc_timestamp is not None else None,
        }

        history_path = self.history_yt_path(export_name)
        last_export_table = history_path + "/last"
        history_table = history_path + "/history"
        if not self.yt_client.exists(history_path):
            self.yt_client.create("map_node", path=history_path, recursive=True)
        self.yt_client.write_table(last_export_table, [row], format=yt.JsonFormat())
        self.yt_client.write_table("<append=true>" + history_table, [row], format=yt.JsonFormat())
        self.yt_client.run_sort(history_table, sort_by="last_export_timestamp")

    def write_mongo_config(self, config_file, password):
        """Write a mongoexport YAML config file containing only the password.

        The password is emitted as a JSON string, which is also a valid YAML
        double-quoted scalar. Unquoted, passwords containing ": ", " #", or a leading
        "*", "&", "!", "%", "@", quote or bracket would break the YAML or be silently
        changed, and numeric ones would be parsed as numbers.
        """
        if not isinstance(password, (str, type(u""))):
            password = str(password)
        with open(config_file, "w") as f:
            f.write("password: {}\n".format(json.dumps(password)))

    def history_yt_path(self, export_name):
        """YT directory holding export bookkeeping for ``export_name`` (split by env)."""
        env = "production" if self.is_stable_env() else "testing"
        return "//home/mediabilling/exports-history/{}/{}".format(env, export_name)

    def resolve_path(self, path):
        """Resolve a path from the spec: "."-prefixed paths are relative to the spec's
        directory, everything else is relative to the Arcadia root."""
        root = self.spec_dir if path.startswith(".") else self.arcadia_dir
        return os.path.join(root, path)


def resolve_value(value):
    """Resolve a ``${sec-<secret id>:<key>}`` Yav reference into the secret value.

    If ``value`` contains such a reference, secret ``<secret id>`` is fetched from
    Yav and its ``<key>`` entry is returned. Any text around the reference is
    ignored, as before. Only the first reference is used, and ``<key>`` may itself
    contain ":".

    Values without a reference are returned unchanged. This includes non-string
    values such as the ints or booleans YAML produces for e.g. a numeric password.
    """
    if not isinstance(value, (str, type(u""))):
        # re.search() would raise TypeError, and such values cannot hold a reference.
        return value
    match = re.search(r"\$\{(sec-[^:}]+):([^}]+)\}", value)
    if match is None:
        return value
    sec, key = match.group(1), match.group(2)
    return sdk2.yav.Secret(sec).data()[key]
