from sandbox import sdk2
from sandbox.projects.common import dolbilka2
from sandbox.projects.collections.recommender_base.shooting_task2 import ShootingTask2
from sandbox.projects.dj.services.entity import helpers
from sandbox.projects.dj.services.entity.components import EntityRecommender
from sandbox.projects.dj.services.entity.helpers import TemporaryDirectory
from sandbox.projects.dj.services.entity.resources import EntityRecommenderShootingPlan
from sandbox.projects.dj.services.entity.resources import EntityRecommenderBundle
from sandbox.projects.dj.services.entity.resources import EntityRecommenderShard
from sandbox.projects.dj.services.entity.resources import EntityRecommenderShard2
from sandbox.projects.dj.services.entity.resources import EntityRecommenderArcadiaModels
from sandbox.projects.resource_types import TASK_CUSTOM_LOGS
from sandbox.projects.tank import executor2 as tank_executor
from sandbox.sandboxsdk import environments
from sandbox.sdk2 import Task
from sandbox.sdk2 import ResourceData
from sandbox.sdk2 import Vault
from sandbox.sdk2.helpers import ProcessLog
from sandbox.sdk2.parameters import Integer
from sandbox.sdk2.parameters import Bool
from sandbox.sdk2.parameters import Dict
from sandbox.sdk2.parameters import String
from sandbox.sdk2.parameters import Resource
from sandbox.sdk2.parameters import List

import os
import json


# Maximum number of access-log rows sent to YT in one write_table request;
# used as the batch size when EntityRecommenderPerfTest uploads the access log.
ACCESS_LOG_MAX_ROWS_PER_YT_WRITE_REQUEST = 1024


# Input parameters of the Entity Recommender performance test task.
# NOTE: the class name is misspelled ("Paramenters"); it is kept as-is so that
# existing references to it keep working.
class EntityRecommenderPerfTestParamenters(Task.Parameters):
    # Both ports are handed to the launched service. The shooting target is
    # http_port, or apphost_port when use_apphost is set.
    http_port = Integer("Http Port", default=15358, required=True)
    apphost_port = Integer("AppHost Port", default=15359, required=True)
    use_apphost = Bool("Use apphost port instead of http", default=False)
    # Inline config patch plus named patches shipped in the bundle; both are
    # combined by helpers.prepare_patches (per the label, the inline patch wins).
    patch = String('Text patch for config (has priority over bundle patches)')
    bundle_patches = List('Patches for config from bundle', default=[])
    # Plan selection: either an explicit plan resource, or attributes that are
    # passed to helpers.find_shooting_plan_resource (presumably used to look a
    # plan up when no explicit resource is given — confirm in helpers).
    specify_plan = Bool('Specify plan', default=False)
    with specify_plan.value[True]:
        shooting_plan = Resource("Resource with dolbilka plan", resource_type=EntityRecommenderShootingPlan, required=True)
    with specify_plan.value[False]:
        plan_attributes = Dict("Required attribute for plan resource", default=dict())
    # Service resources, each resolved by the matching helpers.find_*_resource.
    # NOTE(review): presumably an empty value makes the helper pick a default
    # resource — confirm in helpers.
    bundle_resource = Resource("Resource with Entity Recommender bundle", resource_type=EntityRecommenderBundle)
    shard_resource = Resource("Resource with Entity Recommender shard", resource_type=EntityRecommenderShard)
    shard2_resource = Resource("Resource with Entity Recommender shard 2.0", resource_type=EntityRecommenderShard2)
    models_resource = Resource("Resource with Entity Recommender models", resource_type=EntityRecommenderArcadiaModels)
    # When set, a realtime_update_config.pbtxt path inside the working
    # directory is passed to the service launcher.
    enable_realtime_update = Bool("Enable realtime update", default=False)
    # When set, the service writes an access log into a TASK_CUSTOM_LOGS
    # resource; after shooting it is uploaded to YT if access_log_yt_table is set.
    write_access_log = Bool('Write access log', default=False)
    with write_access_log.value[True]:
        # Name of the Vault record holding the YT token (read via Vault.data).
        yt_token_vault = String('YT Token')
        # YT proxy used for the upload.
        access_log_yt_cluster = String('YT Cluster', default='hahn')
        # Destination table, created with force=True on each upload.
        # Leave empty to skip the upload.
        access_log_yt_table = String('Access log table')
    # Passed to the service launcher as `verbose`.
    verbose_level = Integer("Verbose level", default=0)
    # Parameter groups of the dolbilka executor and the tank plugins;
    # presumably consumed by ShootingTask2's shooting code — confirm.
    dolbilka_param = dolbilka2.DolbilkaExecutor2.Parameters
    lunapark_param = tank_executor.LunaparkPlugin.Parameters
    offline_param = tank_executor.OfflinePlugin.Parameters


# Host requirements of the Entity Recommender performance test task.
class EntityRecommenderPerfTestRequirements(Task.Requirements):
    # Memory and disk sizes; presumably in MiB (Sandbox convention) — confirm.
    ram = 10 * 1024
    disk_space = 10 * 1024
    # Host selection tags are shared with ShootingTask2.
    client_tags = ShootingTask2.client_tags
    cores = 4
    # Pip packages for the YT client; the task imports `yt.wrapper` lazily
    # when uploading the access log.
    environments = (
        environments.PipEnvironment('yandex-yt'),
        environments.PipEnvironment("yandex-yt-yson-bindings-skynet")
    )


# Ids of the resources actually used by a run. They stay None until the task
# resolves its resources during execution and stores their ids here.
class EntityRecommenderPerfTestContext(Task.Context):
    bundle_id = None
    shard_id = None
    shard2_id = None
    models_id = None
    shooting_plan_id = None


class EntityRecommenderPerfTest(Task, ShootingTask2):
    """ Task for testing base Entity Recommender performance """

    class Parameters(EntityRecommenderPerfTestParamenters):
        pass

    class Requirements(EntityRecommenderPerfTestRequirements):
        pass

    class Context(EntityRecommenderPerfTestContext):
        pass

    def on_execute(self):
        # All service files are prepared in a temporary directory under the
        # task's current working directory; service output goes to a ProcessLog.
        with TemporaryDirectory(dir=os.getcwd()) as working_dir, \
                ProcessLog(self, logger="entity_recommender") as process_log:
            self._process(working_dir, process_log)

    def _process(self, working_dir, process_log):
        """Resolve resources, launch the service, shoot it, optionally upload the access log.

        :param working_dir: directory where the binary, config and data are prepared
        :param process_log: ProcessLog handed to the service launcher
        """
        parameters = self.Parameters
        context = self.Context

        # Resolve the resources to use from the task parameters.
        shooting_plan = helpers.find_shooting_plan_resource(
            explicit_resource=parameters.shooting_plan,
            plan_attributes=parameters.plan_attributes)
        shard = helpers.find_shard_resource(parameters.shard_resource)
        shard2 = helpers.find_shard2_resource(parameters.shard2_resource)
        bundle = helpers.find_bundle_resource(parameters.bundle_resource)
        models = helpers.find_models_resource(parameters.models_resource)

        # Record which resources this run actually used.
        context.shooting_plan_id = shooting_plan.id
        context.bundle_id = bundle.id
        context.shard_id = shard.id
        context.shard2_id = shard2.id
        context.models_id = models.id

        # Download resource data and lay it out in the working directory.
        bundle_data = ResourceData(bundle)
        shard_data = ResourceData(shard)
        shard2_data = ResourceData(shard2)
        models_data = ResourceData(models)
        helpers.prepare_entity_recommender_resources(bundle_data, shard_data, shard2_data, models_data, working_dir)

        # Optional outputs/inputs of the service.
        access_log_path = None
        if parameters.write_access_log:
            # The access log is written straight into a new TASK_CUSTOM_LOGS
            # resource owned by this task.
            access_log_data = ResourceData(TASK_CUSTOM_LOGS(self, "access log", "access_log.txt"))
            access_log_path = helpers.resource_data_path(access_log_data)
        realtime_update_config_path = None
        if parameters.enable_realtime_update:
            realtime_update_config_path = os.path.join(working_dir, "realtime_update_config.pbtxt")

        # Shoot the AppHost port instead of the HTTP one when requested.
        shooting_port = parameters.apphost_port if parameters.use_apphost else parameters.http_port

        # The service runs only for the duration of the `with` block.
        with helpers.launch_entity_recommender(
                binary_path=os.path.join(working_dir, "entity_recommender_service"),
                http_port=parameters.http_port,
                apphost_port=parameters.apphost_port,
                config_path=os.path.join(working_dir, "config.pbtxt"),
                patch_paths=helpers.prepare_patches(parameters.bundle_patches, parameters.patch, working_dir),
                working_dir_path=working_dir,
                realtime_update_config_path=realtime_update_config_path,
                process_log=process_log,
                access_log_path=access_log_path,
                verbose=parameters.verbose_level):
            self._init_virtualenv()
            self._dolbilo_shoot(shooting_port, shooting_plan, "perf_test")

        # Upload after the service has stopped, so the log file is complete.
        if parameters.write_access_log:
            self._write_logs_to_yt(access_log_path)

    def _write_logs_to_yt(self, path):
        """Upload the access log at `path` to YT, if a destination table is configured.

        Each non-blank line is parsed with json.loads and written as one row.
        The table is created with force=True, rows are appended in batches of
        ACCESS_LOG_MAX_ROWS_PER_YT_WRITE_REQUEST, and finally the table is merged
        in place with combine_chunks=True (presumably to compact the many small
        chunks produced by batched appends).

        :param path: local path of the access log (assumed one JSON document per line)
        """
        parameters = self.Parameters
        output_table = parameters.access_log_yt_table
        if not output_table:
            # Uploading is optional; the log is still kept in the task resource.
            return

        # Imported lazily: only needed when an upload actually happens.
        import yt.wrapper as yt

        yt.config['proxy']['url'] = parameters.access_log_yt_cluster
        # NOTE(review): yt_token_vault is not a required parameter; an empty
        # value makes Vault.data fail here — consider validating it earlier.
        yt.config['token'] = Vault.data(parameters.yt_token_vault)
        yt.config['pickling']['python_binary'] = '/skynet/python/bin/python'

        yt.create("table", output_table, recursive=True, force=True)
        append_path = yt.TablePath(output_table, append=True)

        rows_batch = []
        with open(path) as log_file:
            for line in log_file:
                # Blank lines carry no row and would make json.loads raise.
                if not line.strip():
                    continue
                rows_batch.append(json.loads(line))
                if len(rows_batch) >= ACCESS_LOG_MAX_ROWS_PER_YT_WRITE_REQUEST:
                    yt.write_table(append_path, rows_batch)
                    rows_batch = []
        # Flush the final, partially filled batch.
        if rows_batch:
            yt.write_table(append_path, rows_batch)
        yt.run_merge(output_table, output_table, spec=dict(combine_chunks=True))

    @sdk2.footer()
    def footer(self):
        # Delegate footer rendering to ShootingTask2.
        return ShootingTask2.footer(self)
