# -*- coding: utf-8 -*-

import collections
import getpass
import itertools
import json
import logging
import os
import random
import requests
import socket
import time

from sandbox import sdk2

from sandbox.sandboxsdk import errors
from sandbox.sandboxsdk import process

from sandbox.projects.common.binary_task import deprecated as binary_task
from sandbox.projects.infra.yp_service_discovery import resources


# YP clusters covered by YpServiceDiscoveryDiffTest. Each name is expanded to the
# YP master address '<name>.yp.yandex.net:8090' and is also sent to YP Service
# Discovery as the request's cluster_name. The list order is the order of the
# per-cluster lines in the task's stats report.
CLUSTERS = [
    'sas-test', 'man-pre',
    'sas', 'man', 'vla', 'myt', 'iva', 'xdc',
]


class YpServiceDiscoveryDiffTest(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """
        Compare YP Service Discovery and YP master

        The task samples endpoint sets from every cluster in CLUSTERS and
        resolves them through a locally started YP Service Discovery binary.
        It then re-reads the same endpoint sets from the YP master at the
        timestamp reported by each resolve response. Entries that differ are
        saved to a YpServiceDiscoveryDiff resource, and the task fails if
        there is at least one such entry.
    """

    CLIENT_NAME = 'YpServiceDiscoveryDiffTest:{host}:{user}'

    class Parameters(sdk2.Parameters):
        ext_params = binary_task.binary_release_parameters(stable=True)

        yp_service_discovery = sdk2.parameters.Resource('Checked YP Service Discovery executable', resource_type=resources.YpServiceDiscovery)

        # Despite the name, this value is used as a fraction: it is multiplied by
        # the number of endpoint sets in each cluster (0.2 samples 20% of them).
        requests_percentage = sdk2.parameters.Float('Percentage of all endpoint sets to send', default=0.2, required=True)

        yp_token_owner = sdk2.parameters.String('Owner of YP_TOKEN vault data', required=True)

        with sdk2.parameters.Output:
            diff = sdk2.parameters.Resource('Result diff', resource_type=resources.YpServiceDiscoveryDiff)

    def on_execute(self):
        """Run the whole comparison and fail the task if any diff is found."""
        super(YpServiceDiscoveryDiffTest, self).on_execute()

        endpoint_sets = self.get_requests()

        yp_sd_process, _, yp_sd_grpc_port, yp_sd_admin_http_port = self.run_yp_service_discovery(self.Parameters.yp_service_discovery)
        try:
            responses = self.get_responses(yp_sd_grpc_port, endpoint_sets)
        finally:
            # Stop the child process even when resolving fails, so it is not leaked.
            self.stop_yp_service_discovery(yp_sd_process, yp_sd_admin_http_port)

        compare_result = self.compare_with_master(responses)
        assert len(compare_result) == len(endpoint_sets)

        diff = [entry for entry in compare_result if entry['diff']]
        self.save_diff(diff, 'diff.json')

        self.print_stats(compare_result, diff)

        if diff:
            raise errors.SandboxTaskFailureError('Non zero number of diffs: {}/{} (diffs/total). See diff in output parameters'.format(len(diff), len(endpoint_sets)))

    def get_requests(self):
        """
        Sample endpoint set ids from every cluster in CLUSTERS.

        Returns a list of (cluster, endpoint_set_id) tuples. For each cluster it
        contains int(count * requests_percentage) randomly chosen endpoint sets.
        """
        token = self._get_yp_token()

        result = []
        for cluster in CLUSTERS:
            yp_client = self._make_yp_client(cluster, token)
            endpoint_sets = [(cluster, esid) for esid in self._select_all_endpoint_set_ids(yp_client)]
            sample_size = int(len(endpoint_sets) * self.Parameters.requests_percentage)
            result += random.sample(endpoint_sets, sample_size)

        return result

    @staticmethod
    def _select_all_endpoint_set_ids(yp_client, limit=1000):
        """Return all endpoint set ids visible to yp_client, paging by limit rows at a time."""
        esids = []
        offset = 0
        while True:
            selected = yp_client.select_objects(
                "endpoint_set",
                selectors=["/meta/id"],
                limit=limit,
                offset=offset,
            )
            esids.extend(row[0] for row in selected)
            # A short page means there is nothing left to read.
            if len(selected) < limit:
                return esids
            offset += limit

    def _make_yp_client(self, cluster, token):
        """Create a YpClient for the YP master of cluster, authenticated with token."""
        from yp.client import YpClient

        return YpClient(address='{}.yp.yandex.net:8090'.format(cluster), config={'token': token})

    def run_yp_service_discovery(self, yp_service_discovery):
        """
        Start the YP Service Discovery binary from the given resource and wait until it answers /ping.

        Returns (process, discovery_http_port, grpc_port, admin_http_port).
        If the binary does not become ready, it is killed before the error propagates.
        """
        yp_service_discovery_path = str(sdk2.ResourceData(yp_service_discovery).path)

        discovery_http_port = 8080
        grpc_port = 8081
        admin_http_port = 8082

        cmd = [
            yp_service_discovery_path, 'run',
            '-V', 'DiscoveryHttpServiceConfig.Port={}'.format(discovery_http_port),
            '-V', 'GrpcServiceConfig.Port={}'.format(grpc_port),
            '-V', 'AdminHttpServiceConfig.Port={}'.format(admin_http_port),
        ]

        environment = os.environ.copy()
        environment['YP_TOKEN'] = self._get_yp_token()

        proc = process.run_process(cmd, wait=False, environment=environment, log_prefix='yp_service_discovery_{}'.format(discovery_http_port))

        # A flag combined with finally, rather than except + bare raise, keeps
        # the original exception intact even on Python 2.
        started = False
        try:
            self._wait_start(discovery_http_port)
            started = True
        finally:
            if not started:
                self._kill_process(proc)

        return proc, discovery_http_port, grpc_port, admin_http_port

    def stop_yp_service_discovery(self, proc, admin_http_port):
        """
        Stop YP Service Discovery: request /shutdown, wait up to 120s, then kill it if still running.
        """
        try:
            r = requests.get('http://127.0.0.1:{}/shutdown'.format(admin_http_port), timeout=120)
            r.raise_for_status()
            proc.wait(timeout=120)
        except Exception:
            # A graceful shutdown is best-effort; the process is killed below regardless.
            logging.exception('Graceful shutdown of YP Service Discovery failed, killing the process')
        self._kill_process(proc)

    @staticmethod
    def _kill_process(proc):
        """Kill proc if it has not exited yet, then reap it."""
        if proc.poll() is None:
            proc.kill()
            proc.wait()

    def _wait_start(self, discovery_http_port, timeout=600, request_timeout=10):
        """
        Poll /ping on discovery_http_port once a second until it returns HTTP 200.

        Each probe is bounded by request_timeout seconds, so a hung connection
        cannot exceed the overall budget.
        Raises SandboxTaskFailureError if no 200 is seen within timeout seconds.
        """
        url = 'http://127.0.0.1:{}/ping'.format(discovery_http_port)
        deadline = time.time() + timeout
        while True:
            try:
                if requests.get(url, timeout=request_timeout).status_code == 200:
                    return
            except requests.RequestException:
                pass  # The server is not accepting connections yet.

            if time.time() > deadline:
                raise errors.SandboxTaskFailureError('Waiting for YP Service Discovery to start reached timeout ({}s)'.format(timeout))

            # Sleep after every unsuccessful probe, including non-200 replies, to avoid busy-looping.
            time.sleep(1)

    def _get_client_name(self):
        """Client name reported to YP Service Discovery: task name, host and OS user."""
        return self.CLIENT_NAME.format(host=socket.gethostname(), user=getpass.getuser())

    def get_responses(self, grpc_port, endpoint_sets):
        """
        Resolve every (cluster, endpoint_set_id) pair via the local gRPC endpoint on grpc_port.

        Returns a list of (cluster, endpoint_set_id, resolve_response) tuples in input order.
        """
        from infra.yp_service_discovery.python.resolver.resolver import Resolver
        from infra.yp_service_discovery.api import api_pb2

        resolver = Resolver(client_name=self._get_client_name(), grpc_address='127.0.0.1:{}'.format(grpc_port))

        responses = []
        for cluster, esid in endpoint_sets:
            request = api_pb2.TReqResolveEndpoints()
            request.cluster_name = cluster
            request.endpoint_set_id = esid
            responses.append((cluster, esid, resolver.resolve_endpoints(request)))

        return responses

    def compare_with_master(self, responses):
        """
        Compare each resolve response with the YP master state at the response's timestamp.

        Returns one dict per response with keys 'cluster', 'esid' and 'diff'.
        'diff' is empty when the endpoint sets match. Otherwise it holds both
        versions, converted with MessageToDict, under 'checked' and 'master'.
        Endpoints of both sets are sorted by id in place before the comparison.
        """
        import google.protobuf.json_format as json_format

        token = self._get_yp_token()
        yp_clients = {cluster: self._make_yp_client(cluster, token) for cluster in CLUSTERS}

        compare_results = []
        for cluster, esid, response in responses:
            yp_client = yp_clients[cluster]
            # Read the master state at the same timestamp the resolver answered for.
            es_selection_result = yp_client.get_object(
                "endpoint_set",
                esid,
                selectors=["/meta/id"],
                timestamp=response.timestamp,
                options={"ignore_nonexistent": True},
            )
            selection_result = yp_client.select_objects(
                "endpoint",
                selectors=["/meta/id", "/meta/endpoint_set_id", "/spec", "/status"],
                filter="[/meta/endpoint_set_id] = '{}'".format(esid),
                timestamp=response.timestamp,
            )

            master_endpoint_set = self._make_proto_from_selection_result(es_selection_result, selection_result)

            # Endpoint order is not significant; normalize it before comparing messages.
            self._sort_endpoints(response.endpoint_set.endpoints)
            self._sort_endpoints(master_endpoint_set.endpoints)

            compare_result = {
                'cluster': cluster,
                'esid': esid,
                'diff': {},
            }

            if response.endpoint_set != master_endpoint_set:
                compare_result['diff'].update({
                    'checked': json_format.MessageToDict(response.endpoint_set),
                    'master': json_format.MessageToDict(master_endpoint_set),
                })

            compare_results.append(compare_result)

        return compare_results

    def _make_proto_from_selection_result(self, es_selection_result, selection_result):
        """
        Build a TEndpointSet from YP master selection results.

        es_selection_result is the endpoint set's [/meta/id] row, or None if it does not exist.
        selection_result holds [/meta/id, /meta/endpoint_set_id, /spec, /status] rows for its endpoints.
        """
        from infra.yp_service_discovery.api import api_pb2

        endpoint_set = api_pb2.TEndpointSet()

        if es_selection_result is not None:
            endpoint_set.endpoint_set_id = es_selection_result[0]

        for meta_id, _, spec, status in selection_result:
            endpoint = endpoint_set.endpoints.add()
            endpoint.id = meta_id
            spec = spec or {}
            # Copy only the spec fields that are present, leaving the proto defaults otherwise.
            for field in ('protocol', 'fqdn', 'ip4_address', 'ip6_address', 'port'):
                if field in spec:
                    setattr(endpoint, field, spec[field])
            if status and status.get('ready', False):
                endpoint.ready = True

        return endpoint_set

    def _sort_endpoints(self, endpoints):
        """Sort a repeated endpoints field in place by endpoint id."""
        endpoints.sort(key=lambda endpoint: endpoint.id)

    def _get_yp_token(self):
        """Fetch the YP_TOKEN vault secret owned by Parameters.yp_token_owner."""
        return sdk2.Vault.data(self.Parameters.yp_token_owner, 'YP_TOKEN')

    def save_diff(self, diff, path):
        """Write diff as indented JSON into a new YpServiceDiscoveryDiff resource (ttl=30) and publish it."""
        self.Parameters.diff = resources.YpServiceDiscoveryDiff(self, 'YP Service Discovery diff', path, ttl=30)
        data = sdk2.ResourceData(self.Parameters.diff)
        with open(str(data.path), 'w') as f:
            json.dump(diff, f, indent=2)
        data.ready()

    def print_stats(self, compare_result, diff):
        """Report total and per-cluster request and diff counts in the task info."""
        requests_by_cluster = collections.Counter(entry['cluster'] for entry in compare_result)
        diffs_by_cluster = collections.Counter(entry['cluster'] for entry in diff)

        lines = ["Requests: {}".format(len(compare_result)), "", "Requests by cluster:"]
        lines.extend("{}: {}".format(cluster, requests_by_cluster[cluster]) for cluster in CLUSTERS)
        lines.extend(["", "Total diffs: {}".format(len(diff)), "", "Diffs by cluster:"])
        lines.extend("{}: {}".format(cluster, diffs_by_cluster[cluster]) for cluster in CLUSTERS)

        self.set_info("\n".join(lines) + "\n")
