# -*- coding: utf-8 -*-

import datetime
import json
import logging
import os
import random
import socket
import time

from contextlib import closing

from sandbox import sdk2
from sandbox.sandboxsdk import errors
from sandbox.sandboxsdk import process

from sandbox.projects.common.binary_task import deprecated as binary_task
from sandbox.projects.common import network
from sandbox.projects.infra.yp_dns import resources
from sandbox.projects.websearch.params import ResourceWithLastReleasedValueByDefault


# YP clusters whose dns_record_set objects are sampled to build AAAA/SRV/PTR queries.
CLUSTERS = 'sas-test man-pre sas man vla myt iva xdc'.split()

# Zones that are queried for their NS and SOA records.
ZONES = [
    # *.yp-c.yandex.net zones (plus yp-test.yandex.net right after sas-test).
    'sas-test.yp-c.yandex.net', 'yp-test.yandex.net', 'man-pre.yp-c.yandex.net',
    'sas.yp-c.yandex.net', 'man.yp-c.yandex.net', 'vla.yp-c.yandex.net',
    'myt.yp-c.yandex.net', 'iva.yp-c.yandex.net', 'xdc.yp-c.yandex.net',
    # gencfg / in.* zones.
    'gencfg-c.yandex.net', 'in.yandex.net', 'in.yandex-team.ru',
    # qloud-* zones.
    'stable.qloud-b.yandex.net', 'prestable.qloud-b.yandex.net', 'test.qloud-b.yandex.net',
    'qloud-c.yandex.net',
    'stable.qloud-d.yandex.net', 'prestable.qloud-d.yandex.net', 'test.qloud-d.yandex.net',
]

# Buckets a textual DNS response is split into before comparison. ATTRIBUTES
# collects the lines that appear before the first section marker.
SECTIONS = ['ATTRIBUTES', 'QUESTION', 'ANSWER', 'AUTHORITY', 'ADDITIONAL']
ATTRIBUTES, QUESTION, ANSWER, AUTHORITY, ADDITIONAL = SECTIONS

# Placeholder stored instead of a response text when a DNS query timed out.
TIMED_OUT = 'Timed Out'


class YpDnsDiffTest(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """
        Compare production DNS and local YP.DNS

        The same set of DNS queries is sent to a "sample" nameserver (an existing
        nameserver address or a locally started YP.DNS, chosen by sample_type)
        and to a locally started "checked" YP.DNS. Per-query differences are
        saved as a resource, and the task fails if any difference is found.
    """

    class Parameters(sdk2.Parameters):
        ext_params = binary_task.binary_release_parameters(stable=True)

        sample_type = sdk2.parameters.RadioGroup(
            'Sample DNS type',
            choices=(
                ('Nameserver address of YP.DNS', 'nameserver_address'),
                ('Run local YP.DNS', 'local_nameserver'),
            ),
            required=True,
            default='nameserver_address',
            sub_fields={
                'nameserver_address': [
                    'sample_dns_address'
                ],
                'local_nameserver': [
                    'sample_yp_dns_components',
                    'sample_yp_dns_executable',
                    'sample_pdns_config',
                ],
            }
        )

        sample_dns_address = sdk2.parameters.String('Sample DNS address', default='ns1.yp-dns.yandex.net:53', required=True)

        with sdk2.parameters.Group('Sample YP.DNS components') as sample_yp_dns_components:
            sample_yp_dns_executable = sdk2.parameters.Resource('Sample YP.DNS executable', resource_type=resources.YpDns, required=True)
            sample_pdns_config = sdk2.parameters.Resource('Sample PowerDNS config', resource_type=resources.YpDnsPdnsConf, required=True)

        with sdk2.parameters.Group('Checked YP.DNS components') as checked_yp_dns_components:
            checked_yp_dns_executable = sdk2.parameters.Resource('Checked YP.DNS executable', resource_type=resources.YpDns, required=True)
            checked_pdns_config = sdk2.parameters.Resource('Checked PowerDNS config', resource_type=resources.YpDnsPdnsConf, required=True)

        with sdk2.parameters.Group('Diff parameters') as diff_params:
            get_queries_from_resource = sdk2.parameters.Bool('Use queries from DNS_QUERIES resource', default=False)
            with get_queries_from_resource.value[True]:
                queries = sdk2.parameters.Resource('Resource with DNS queries', resource_type=resources.DnsQueries, required=True)

            queries_percentage = sdk2.parameters.Float('Percentage of all DNS records to send (skipping option if number_of_queries is not 0)', default=0.2)
            number_of_queries = sdk2.parameters.Integer('Maximum number of queries (0 = all)', default=1000)

        yp_token_owner = sdk2.parameters.String('Owner of YP_TOKEN vault data', required=True)
        yt_token_owner = sdk2.parameters.String('Owner of YT_TOKEN vault data', required=True)

        with sdk2.parameters.Output:
            diff = sdk2.parameters.Resource('Result diff', resource_type=resources.YpDnsDiff)

    def on_create(self):
        # Preselect the sample-side components via ResourceWithLastReleasedValueByDefault
        # (presumably the last released YpDns binary / PowerDNS config).
        self.Parameters.sample_yp_dns_executable = sdk2.Resource[ResourceWithLastReleasedValueByDefault(resource_type=resources.YpDns).default_value]
        self.Parameters.sample_pdns_config = sdk2.Resource[ResourceWithLastReleasedValueByDefault(resource_type=resources.YpDnsPdnsConf).default_value]

    def on_execute(self):
        """Collect queries, query both nameservers, save the diff and fail on any difference.

        Raises:
            errors.SandboxTaskFailureError: if at least one difference was found.
        """
        super(YpDnsDiffTest, self).on_execute()

        number_of_queries = self.Parameters.number_of_queries
        # An explicit query limit takes precedence over the percentage.
        queries_percentage = 1.0 if number_of_queries else self.Parameters.queries_percentage

        queries = self.get_queries(queries_percentage, number_of_queries)

        if not self.Parameters.get_queries_from_resource:
            self.save_queries(queries, 'dns_queries.txt')

        # Locally started YP.DNS processes; stopped once all responses are collected,
        # including when something fails half-way.
        local_processes = []
        try:
            # sample_type decides the sample side. The local sample components are
            # always filled by on_create, so their presence cannot be used for this.
            if self.Parameters.sample_type == 'local_nameserver':
                sample_process, sample_host, sample_port = self.get_yp_dns(
                    yp_dns_executable=self.Parameters.sample_yp_dns_executable,
                    pdns_config=self.Parameters.sample_pdns_config,
                )
            else:
                sample_process, sample_host, sample_port = self.get_yp_dns(address=self.Parameters.sample_dns_address)
            local_processes.append(sample_process)

            checked_process, checked_host, checked_port = self.get_yp_dns(
                yp_dns_executable=self.Parameters.checked_yp_dns_executable,
                pdns_config=self.Parameters.checked_pdns_config,
            )
            local_processes.append(checked_process)

            addresses = [
                (sample_host, sample_port),
                (checked_host, checked_port),
            ]
            sample_responses, checked_responses = self.get_responses(addresses, queries)
        finally:
            for proc in local_processes:
                self._stop_yp_dns(proc)

        responses = {}
        responses['{}:{}'.format(*addresses[0])] = sample_responses
        responses['{}:{}'.format(*addresses[1])] = checked_responses
        for address, resps in responses.items():
            self.save_responses(resps, 'responses_{}.txt'.format(address.replace(':', '_')))

        compare_result = self.compare_responses(queries, responses)

        diff = [entry for entry in compare_result if entry['diff']]
        logging.info(json.dumps(diff, indent=2))
        self.save_diff(diff, 'diff.json')

        diffs_num = self.calc_stats(compare_result)

        if diffs_num > 0:
            raise errors.SandboxTaskFailureError('Non zero number of diffs: {} diffs per {} queries. See diff in output parameters'.format(diffs_num, len(queries)))

    def get_yp_dns(self, address=None, yp_dns_executable=None, pdns_config=None):
        """Return (process, host, port) of a nameserver to query.

        If both yp_dns_executable and pdns_config are given, a local YP.DNS is
        started and process is its handle. Otherwise address ("host:port"; an
        IPv6 host may be written in brackets) is parsed and process is None.
        """
        if yp_dns_executable is not None and pdns_config is not None:
            logging.info('Run local YP.DNS')
            proc, port = self.run_yp_dns(yp_dns_executable, pdns_config)
            return proc, '127.0.0.1', port

        # rpartition keeps colons inside an IPv6 host intact.
        host, _, port = address.rpartition(':')
        return None, host.strip('[]'), int(port)

    def run_yp_dns(self, yp_dns, pdns_config):
        """Start YP.DNS on a free local port and wait until it accepts TCP connections.

        Returns:
            (proc, port): the handle returned by process.run_process and the port.
        """
        yp_dns_path = str(sdk2.ResourceData(yp_dns).path)
        pdns_config_path = str(sdk2.ResourceData(pdns_config).path)

        port = network.get_free_port()

        work_dir = 'yp_dns_{}'.format(port)
        self.path(work_dir).mkdir()

        cmd = [
            yp_dns_path, 'run',
            '--config-dir={}'.format(os.path.dirname(pdns_config_path)),
            '--local-port={}'.format(port),
        ]

        # --instance-name is only passed to binaries built from r7370324 onwards
        # (presumably older binaries do not know the flag).
        if int(yp_dns.arcadia_revision) >= 7370324:
            cmd.append('--instance-name=sandbox-task-{}-{}'.format(self.id, port))

        environment = os.environ.copy()
        environment.update({
            'YP_TOKEN': self._get_yp_token(),
            'YT_TOKEN': sdk2.Vault.data(self.Parameters.yt_token_owner, 'YT_TOKEN'),
        })

        proc = process.run_process(cmd, wait=False, work_dir=work_dir, environment=environment, log_prefix='yp_dns_{}'.format(port))

        started = False
        try:
            self._wait_start(port)
            started = True
        finally:
            # Do not leave a half-started process behind if it never came up.
            if not started:
                self._stop_yp_dns(proc)

        return proc, port

    @staticmethod
    def _stop_yp_dns(proc):
        """Stop a locally started YP.DNS process on a best-effort basis.

        proc may be None (remote nameserver, nothing to stop). This assumes the
        object returned by process.run_process(wait=False) exposes the
        subprocess.Popen poll()/terminate() interface (TODO confirm). Failures
        are logged, not raised, so cleanup never masks the task's real error.
        """
        if proc is None:
            return
        try:
            if proc.poll() is None:
                proc.terminate()
        except Exception:
            logging.exception('Failed to stop local YP.DNS process')

    def _wait_start(self, port, timeout=120, poll_interval=1):
        """Poll 127.0.0.1:port over TCP until it accepts a connection.

        Sleeps poll_interval seconds between attempts instead of busy-looping.

        Raises:
            errors.SandboxTaskFailureError: if nothing accepts within timeout seconds.
        """
        start_time = time.time()
        while True:
            with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
                sock.settimeout(10)
                if sock.connect_ex(('127.0.0.1', port)) == 0:
                    return
            if time.time() - start_time > timeout:
                raise errors.SandboxTaskFailureError('Waiting for YP.DNS to start reached timeout ({}s)'.format(timeout))
            time.sleep(poll_interval)

    def _sample(self, values, perc=1.0, max_num=None):
        """Return int(len(values) * perc) randomly chosen elements of values in random order.

        The sample size is clamped to [0, len(values)]. The result is then cut
        to max_num elements when max_num is truthy (None or 0 mean no limit).
        """
        sample_size = max(0, min(len(values), int(len(values) * perc)))
        result = random.sample(values, sample_size)
        if max_num:
            result = result[:max_num]
        return result

    def get_queries(self, queries_percentage, number_of_queries):
        """Build the list of (domain, record_type) queries to send.

        The queries come either from the DNS_QUERIES resource (one
        whitespace-separated "domain type" pair per line; blank lines are
        ignored) or from NS/SOA queries for ZONES plus sampled
        dns_record_sets of every YP cluster.
        """
        if self.Parameters.get_queries_from_resource:
            with open(str(sdk2.ResourceData(self.Parameters.queries).path)) as f:
                queries = [tuple(line.split()) for line in f if line.strip()]
            return self._sample(queries, perc=queries_percentage, max_num=number_of_queries)

        queries = self._get_ns_queries() + self._get_soa_queries()
        for cluster in CLUSTERS:
            cluster_queries = self._get_queries_from_yp(cluster, max_queries=number_of_queries)
            queries += self._sample(cluster_queries, perc=queries_percentage)

        return self._sample(queries, max_num=number_of_queries)

    def save_queries(self, queries, path):
        """Write queries as "domain<TAB>type" lines into a new DnsQueries resource."""
        resource = resources.DnsQueries(self, 'DNS queries', path, ttl=30)
        data = sdk2.ResourceData(resource)
        with open(str(data.path), 'w') as f:
            f.writelines('{}\t{}\n'.format(domain, record_type) for domain, record_type in queries)
        data.ready()

    def _get_zone_queries(self, record_type):
        """Return one (zone, record_type) query per entry of ZONES plus a deliberately nonexistent zone."""
        return [(zone, record_type) for zone in ZONES + ['nonexistent-zone.yp-c.yandex.net']]

    def _get_ns_queries(self):
        return self._get_zone_queries('NS')

    def _get_soa_queries(self):
        return self._get_zone_queries('SOA')

    def _get_queries_from_yp(self, cluster, max_queries=None, page_size=1000, min_age_seconds=60):
        """Collect (record_set_id, record_type) queries from dns_record_set objects of a YP cluster.

        Only record sets whose first record has type AAAA, SRV or PTR are
        considered. Record sets created less than min_age_seconds ago are
        skipped, presumably because a fresh record may not have reached YP.DNS
        yet. Paging stops for a record type once max_queries (if truthy)
        queries have been accumulated in total.
        """
        from yp.client import YpClient

        record_types = ('AAAA', 'SRV', 'PTR')

        def select(yp_client, offset, limit, record_type):
            result = yp_client.select_objects(
                "dns_record_set",
                selectors=["/meta/id", "/meta/creation_time"],
                filter="[/spec/records/0/type] = \"{}\"".format(record_type),
                limit=limit,
                offset=offset,
            )

            now = time.time()
            queries = []
            for record_set_id, creation_time in result:
                # creation_time is treated as microseconds since the epoch. Comparing
                # plain epoch seconds avoids mixing a local-time datetime with utcnow().
                if now - creation_time / 1e6 <= min_age_seconds:
                    continue
                queries.append((record_set_id, record_type))

            return len(result), queries

        yp_client = YpClient(address='{}.yp.yandex.net:8090'.format(cluster), config={'token': self._get_yp_token()})
        queries = []
        for record_type in record_types:
            offset = 0
            while True:
                selected, new_queries = select(yp_client, offset, page_size, record_type)
                if selected == 0:
                    break
                queries += new_queries
                offset += page_size
                if max_queries and len(queries) >= max_queries:
                    break

        return queries

    def get_responses(self, addresses, queries):
        """Send every query over UDP to both (host, port) addresses.

        Each query is retried up to 5 times with exponential backoff. If it still
        times out, TIMED_OUT is stored instead of the response text. There is a
        2 s pause after every 100 queries.

        Returns:
            Two lists of response texts, index-aligned with queries, in the
            order of addresses.
        """
        import retry
        import dns.exception
        from infra.yp_dns.daemon import DnsClient

        @retry.retry(tries=5, delay=2, backoff=2)
        def _query(dns_client, domain, record_type):
            return dns_client.udp(domain, record_type, timeout=5)

        assert len(addresses) == 2

        dns_clients = [DnsClient(self._get_address(host, port), port) for host, port in addresses]

        responses = [[] for _ in dns_clients]
        for idx, (domain, record_type) in enumerate(queries):
            if idx > 0 and idx % 100 == 0:
                time.sleep(2)

            for dns_client, client_responses in zip(dns_clients, responses):
                try:
                    resp = _query(dns_client, domain, record_type).to_text()
                except dns.exception.Timeout:
                    resp = TIMED_OUT
                client_responses.append(resp)

        return responses

    def _get_address(self, hostname, port):
        """Resolve hostname and return the address of the first getaddrinfo result."""
        return socket.getaddrinfo(hostname, port)[0][4][0]

    def save_responses(self, responses, path):
        """Dump responses as JSON into a new YpDnsResponses resource."""
        resource = resources.YpDnsResponses(self, 'YP.DNS responses', path, ttl=30)
        data = sdk2.ResourceData(resource)
        with open(str(data.path), 'w') as f:
            json.dump(responses, f, indent=2)
        data.ready()

    def compare_responses(self, queries, responses):
        """Compare the responses of exactly two nameservers query by query.

        Args:
            queries: list of (domain, record_type).
            responses: dict nameserver -> list of response texts aligned with queries.

        Returns:
            One {'domain', 'type', 'diff'} dict per query. If any side timed
            out, diff maps each nameserver to TIMED_OUT or to its raw response.
            Otherwise diff maps each differing section name to
            {nameserver: sorted section lines}.
        """
        assert len(responses) == 2

        diffs = []
        for idx, (domain, record_type) in enumerate(queries):
            diff = {}
            responses_by_sections = []
            for nameserver, resps in responses.items():
                response = resps[idx]
                if response == TIMED_OUT:
                    diff[nameserver] = TIMED_OUT
                else:
                    # The 'id' line is skipped; presumably the message id differs per query.
                    responses_by_sections.append((nameserver, self._get_response_by_sections(response, skip=['id'])))

            if len(responses_by_sections) < 2:
                for nameserver, resps in responses.items():
                    if resps[idx] != TIMED_OUT:
                        diff[nameserver] = resps[idx]
            else:
                (first_nameserver, first_response), (second_nameserver, second_response) = responses_by_sections
                for section_name in SECTIONS:
                    if first_response[section_name] != second_response[section_name]:
                        diff[section_name] = {
                            first_nameserver: first_response[section_name],
                            second_nameserver: second_response[section_name],
                        }

            diffs.append({
                'domain': domain,
                'type': record_type,
                'diff': diff,
            })

        return diffs

    def _get_response_by_sections(self, resp, skip=()):
        """Split a textual DNS response into SECTIONS buckets.

        Lines of the form ';<SECTION>' switch the current bucket. Lines before
        the first marker go to ATTRIBUTES. Lines starting with any prefix in
        skip are dropped. Each bucket is sorted so record order does not
        produce diffs.
        """
        skip = tuple(skip)
        result = {section_name: [] for section_name in SECTIONS}

        section_name = ATTRIBUTES

        for line in resp.split('\n'):
            if line.startswith(skip):
                continue

            if line[1:] in SECTIONS:
                section_name = line[1:]
                continue

            result[section_name].append(line)

        for lines in result.values():
            lines.sort()

        return result

    def save_diff(self, diff, path):
        """Dump diff as JSON into a new YpDnsDiff resource and expose it as the output parameter."""
        self.Parameters.diff = resources.YpDnsDiff(self, 'DNS diff', path, ttl=30)
        data = sdk2.ResourceData(self.Parameters.diff)
        with open(str(data.path), 'w') as f:
            json.dump(diff, f, indent=2)
        data.ready()

    def calc_stats(self, compare_result):
        """Publish diff statistics via set_info and return the total number of diffs.

        A "diff" is one entry of a query's diff dict (a differing section or a
        per-nameserver timeout entry). Statistics are broken down by every
        record type present in compare_result (sorted) and by section, in
        SECTIONS order.
        """
        def count_diffs(entries):
            return sum(len(entry['diff']) for entry in entries)

        total_diffs = count_diffs(compare_result)
        record_types = sorted(set(entry['type'] for entry in compare_result))

        lines = [
            'Queries: {}'.format(len(compare_result)),
            'Total diffs: {}'.format(total_diffs),
            '',
            'Diffs by record type:',
        ]
        for record_type in record_types:
            type_diffs = count_diffs(entry for entry in compare_result if entry['type'] == record_type)
            lines.append('{}: {}'.format(record_type, type_diffs))

        lines.extend(['', 'Diffs by section:'])
        for section_name in SECTIONS:
            section_diffs = sum(1 for entry in compare_result if section_name in entry['diff'])
            lines.append('{}: {}'.format(section_name, section_diffs))

        self.set_info('\n'.join(lines) + '\n')

        return total_diffs

    def _get_yp_token(self):
        """Read YP_TOKEN from the vault of yp_token_owner."""
        return sdk2.Vault.data(self.Parameters.yp_token_owner, 'YP_TOKEN')
