import hashlib
import json
import logging
import random
import requests
import socket
import time
import subprocess

from sandbox import sdk2
from sandbox.common.types.misc import DnsType
from sandbox.projects.common import resource_selectors
from sandbox.projects.common.geosearch.base_update import generate_base_update_task_sdk2
from sandbox.projects.common.nanny.nanny import ReleaseToNannyTask2
from sandbox.projects.ydo.resource_types import YdoEmbeddingServer, YdoEmbeddingServerDssmModel, YdoServiceEmbeddings
from sandbox.projects.ydo.rubrics_merger.YdoRubricsMerger import YdoMergedRubricsSmallDump
import sandbox.common.types.task as ctt


class ServiceMatcher(ReleaseToNannyTask2, sdk2.Task):
    """Build embeddings of service rubrics for the service matcher.

    Loads the YdoMergedRubricsSmallDump resource of the last stable release, runs a
    local embedding server (binary and DSSM model come from the task parameters),
    asks it for the embedding of every rubric name and publishes the result as a
    YdoServiceEmbeddings resource (``service_embeddings.json``): a JSON object that
    maps rubric id to the list of floats returned by the server.
    """

    class Parameters(sdk2.Parameters):
        description = 'Build embeddings for service matcher'

        embedding_server_binary = sdk2.parameters.Resource('Embedding server binary', resource_type=YdoEmbeddingServer, required=True)
        dssm = sdk2.parameters.Resource('Dssm model for embeddings', resource_type=YdoEmbeddingServerDssmModel, required=True)

    @staticmethod
    def _find_stable_rubrics_dump():
        """Return the YdoMergedRubricsSmallDump resource of the last stable release.

        Raises:
            ValueError: if no such resource can be found.
        """
        resource_id = resource_selectors.by_last_released_task(YdoMergedRubricsSmallDump, stage=ctt.ReleaseStatus.STABLE)[0]
        # NOTE(review): a missing id is treated as "nothing released"; querying with
        # id=None could otherwise silently match an arbitrary resource of this type.
        resource = None
        if resource_id is not None:
            resource = sdk2.Resource.find(YdoMergedRubricsSmallDump, id=resource_id).first()
        if resource is None:
            raise ValueError('No stable released YdoMergedRubricsSmallDump resource found')
        return resource

    @staticmethod
    def _stop_process(process):
        """Kill *process* if it is still running, then reap it."""
        if process.poll() is None:
            process.kill()
        process.wait()

    def start_embedding_server(self):
        """Launch the embedding server on a random free port and wait until it answers.

        Returns:
            (subprocess.Popen, str): the server process and the URL of its ``/dssm``
            endpoint. The caller is responsible for stopping the process.

        Raises:
            ValueError: if the server exits or does not answer within
                ``max_wait_steps`` polling attempts; the process is killed first.
        """
        def is_port_available(port):
            # Probe by binding. Another process may still grab the port before the
            # server binds it; a random high port makes that unlikely.
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.bind(('0.0.0.0', port))
            except socket.error:
                return False
            finally:
                sock.close()
            return True

        def find_free_port():
            for _ in range(10000):
                port = random.randint(10000, 65535)
                if is_port_available(port):
                    return port
            raise Exception('Failed to find free port')

        embedding_server_path = str(sdk2.ResourceData(self.Parameters.embedding_server_binary).path)
        dssm_path = str(sdk2.ResourceData(self.Parameters.dssm).path)
        port = find_free_port()

        logging.debug('Starting embedding server process')
        sub = subprocess.Popen([embedding_server_path, '--dssm-path', dssm_path, '--port', str(port)])

        url = 'http://localhost:{}/dssm'.format(port)

        max_wait_steps = 180
        for _ in range(max_wait_steps):
            if sub.poll() is not None:
                # The server has exited; further polling cannot succeed.
                break
            try:
                # Bounded timeout so a half-started server cannot hang this loop.
                requests.get(url, params={'text': 'test'}, timeout=10).raise_for_status()
            except requests.RequestException:
                time.sleep(1)
            else:
                logging.debug('Started embedding server process')
                return sub, url

        logging.error('Embedding server is not ready, exit code: %s', sub.poll())
        self._stop_process(sub)
        raise ValueError('Failed to start embedding server')

    def get_embedding(self, url, text, timeout=60):
        """Return the embedding of *text* as a list of floats.

        The server is expected to answer with whitespace-separated numbers.
        *timeout* (seconds) bounds the HTTP request so a stuck server cannot hang
        the task.
        """
        response = requests.get(url, params={'text': text}, timeout=timeout)
        response.raise_for_status()
        return [float(v) for v in response.text.split()]

    def on_execute(self):
        service_embeddings_resource = sdk2.ResourceData(YdoServiceEmbeddings(self, 'Service Embeddings', 'service_embeddings.json'))

        resource = self._find_stable_rubrics_dump()
        with open(str(sdk2.ResourceData(resource).path)) as rubrics_file:
            rubrics = json.load(rubrics_file)

        sub, url = self.start_embedding_server()
        try:
            service_embeddings = {}
            for rubric in rubrics.values():
                rubric_id = rubric['id']
                if not isinstance(rubric_id, str):
                    # Python 2: json yields ``unicode``; encode to a byte ``str`` as before.
                    # Python 3: ids are already ``str``; encoding them would give ``bytes``
                    # keys, which json.dump rejects.
                    rubric_id = rubric_id.encode('utf-8')
                service_embeddings[rubric_id] = self.get_embedding(url, rubric['name'])
        finally:
            # Always stop the server, even if an embedding request failed.
            self._stop_process(sub)

        with open('service_embeddings.json', 'w') as out:
            json.dump(service_embeddings, out, ensure_ascii=False)

        service_embeddings_resource.ready()


# Base class for UpdateServiceMatcher, produced by the generic geosearch
# base-update factory for the ServiceMatcher task.
BaseUpdate = generate_base_update_task_sdk2(
    ServiceMatcher,
    release_subject="Update ServiceMatcher",
    force_rebuild=False,
)


class UpdateServiceMatcher(BaseUpdate):
    """
        Update ServiceMatcher

        In addition to the base update check, an update is requested when the
        content of the last stable YdoMergedRubricsSmallDump changes. An MD5 of
        the canonicalized dump is kept in the ``rubrics_hash`` attribute of a YT
        table on locke and is refreshed after a successful update.
    """

    class Requirements(BaseUpdate.Requirements):
        environments = [
            sdk2.environments.PipEnvironment('yandex-yt'),
        ]
        dns = DnsType.DNS64
        cores = 1

        class Caches(BaseUpdate.Requirements.Caches):
            pass

    @staticmethod
    def _find_stable_rubrics_dump():
        """Return the YdoMergedRubricsSmallDump resource of the last stable release.

        Raises:
            ValueError: if no such resource can be found.
        """
        resource_id = resource_selectors.by_last_released_task(YdoMergedRubricsSmallDump, stage=ctt.ReleaseStatus.STABLE)[0]
        # NOTE(review): a missing id is treated as "nothing released"; querying with
        # id=None could otherwise silently match an arbitrary resource of this type.
        resource = None
        if resource_id is not None:
            resource = sdk2.Resource.find(YdoMergedRubricsSmallDump, id=resource_id).first()
        if resource is None:
            raise ValueError('No stable released YdoMergedRubricsSmallDump resource found')
        return resource

    def _compute_rubrics_hash(self):
        """Return the hex MD5 of the stable rubrics dump in canonical JSON form."""
        resource = self._find_stable_rubrics_dump()
        with open(str(sdk2.ResourceData(resource).path)) as rubrics_file:
            rubrics = json.load(rubrics_file)
        # Sorted keys and ASCII-only output make the hash depend on content only,
        # not on key order or the formatting of the dump file.
        rubrics_text = json.dumps(rubrics, ensure_ascii=True, sort_keys=True)
        # hexdigest() rather than digest(): the value is logged, stored in the task
        # Context and in a YT attribute, and raw binary strings do not survive those.
        # (A hash stored in the old binary format just triggers one extra update.)
        return hashlib.md5(rubrics_text.encode('utf-8')).hexdigest()

    def _configure_yt(self):
        """Import ``yt.wrapper`` and point it at locke with the owner's ``yt-token``."""
        # Imported lazily: presumably the package is only available inside the
        # task environment (installed by the PipEnvironment in Requirements).
        import yt.wrapper as yt

        yt.config['token'] = sdk2.Vault.data(self.owner, 'yt-token')
        yt.config['proxy']['url'] = 'locke.yt.yandex.net'
        return yt

    def _check_need_update(self):
        """Return True when the base check or a changed rubrics hash requires an update.

        When the hash differs, the YT attribute path and the new hash are saved to
        the Context so on_execute can persist the hash after the update.
        """
        # noinspection PyProtectedMember
        need_update = super(UpdateServiceMatcher, self)._check_need_update()
        if need_update:
            logging.debug('need_update already set, skipping hash check')
            return need_update

        rubrics_hash = self._compute_rubrics_hash()
        logging.debug('Calculated hash for rubrics dump: %s', rubrics_hash)

        yt = self._configure_yt()

        table_path = '//home/ydo/sandbox/{}/ServiceMatcher'.format(ctt.ReleaseStatus.STABLE)
        # The table is only a holder for the attribute; create it with an empty hash
        # on first use so the very first check always reports a change.
        yt.create('table', table_path, attributes={'rubrics_hash': ''}, ignore_existing=True, recursive=True)
        hash_attr_path = table_path + '/@rubrics_hash'
        stored_hash = yt.get(hash_attr_path)
        logging.debug('Stored hash: %s', stored_hash)

        if stored_hash != rubrics_hash:
            self.Context.hash_changed = True
            self.Context.hash_attr_path = hash_attr_path
            self.Context.rubrics_hash = rubrics_hash
            return True
        self.Context.hash_changed = False
        return False

    def on_execute(self):
        super(UpdateServiceMatcher, self).on_execute()

        # hash_changed is only assigned by _check_need_update's hash check; when that
        # check was skipped the value is unset, which is relied upon to be falsy.
        # NOTE(review): confirm sdk2 Context semantics for unset attributes.
        if not self.Context.hash_changed:
            logging.debug('Hash has not changed or not checked, skip saving hash')
            return

        yt = self._configure_yt()
        # Stored only after the base on_execute returned without raising, so a failed
        # run leaves the old hash in place and the next check requests an update again.
        yt.set(self.Context.hash_attr_path, self.Context.rubrics_hash)
