# -*- coding: utf-8 -*-

import logging
import os
import time

from collections import defaultdict

from sandbox.projects.common import error_handlers as eh
from .proto_transport import ProtoTransport


# Fixed group/user ids that SanityBioJob puts into its requests and contexts
# and then expects to see echoed back in the server responses.
GROUP_ID = 'test_group_id'
USER_ID = 'test_user_id'


class UnexpectedResponse(Exception):
    """Raised when the server sends a response that the protocol did not expect.

    Derives from ``Exception`` because on Python 3 only ``BaseException``
    subclasses can be raised; a plain class would turn ``raise`` into a
    ``TypeError`` that a broad ``except Exception`` silently swallows.
    """
    pass


def bytes_from_file(filename, chunksize=8192):
    """Yield the binary contents of ``filename`` piece by piece.

    Reads ``chunksize`` bytes per step and stops at the first empty read
    (end of file). The file is closed when the generator finishes.
    """
    with open(filename, "rb") as stream:
        piece = stream.read(chunksize)
        while piece:
            yield piece
            piece = stream.read(chunksize)


def maybe_join(base_path, relpath):
    """Return ``relpath`` joined onto ``base_path``, or None when ``relpath`` is None."""
    return None if relpath is None else os.path.join(base_path, relpath)


class SanityBioJob:
    """One sanity check of a yabio server, configured from a dict.

    Supported ``protocol`` values:

    * ``'proto_classify'`` (default): stream ``audio_file`` (optionally
      preceded by ``audio_spotter_file``) with the Classify method and compare
      the per-tag classnames with ``expected_result``.
    * ``'proto_score'``: stream every ``(spotter_file, audio_file)`` pair from
      ``enroll_files`` with the Score method to collect voiceprints, then score
      ``audio_file`` for one user owning those voiceprints and compare
      ``score < 0.5`` with ``expected_result['is_guest']``.

    Relative file paths from the config are resolved against ``base_dir``.
    Failed checks are reported through ``eh.check_failed``.
    """

    def __init__(self, cfg, base_dir):
        """Read and validate the job options.

        :param cfg: dict with the job options (see the class docstring).
        :param base_dir: directory that relative audio paths are resolved against.
        """
        self.name = cfg.get('name', 'unknown')
        self.protocol = cfg.get('protocol', 'proto_classify')
        self.audio_spotter_file = cfg.get('audio_spotter_file')
        self.audio_file = cfg.get('audio_file')
        self.audio_mime = cfg.get('audio_mime')
        if self.protocol in ('proto_classify', 'proto_score'):
            if not self.audio_file:
                eh.check_failed('sanity job MUST have "audio_file" option')
            if not self.audio_mime:
                eh.check_failed('sanity job MUST have "audio_mime" option')
            self.audio_file = str(os.path.join(base_dir, self.audio_file))
            if self.audio_spotter_file:
                self.audio_spotter_file = str(os.path.join(base_dir, self.audio_spotter_file))

        self.classification_tags = cfg.get('classification_tags', [])
        self.enroll_files = cfg.get('enroll_files')

        if self.protocol == 'proto_score':
            if not self.enroll_files:
                eh.check_failed('sanity job MUST have "enroll_files" option')
            # Each entry is a (spotter_file, audio_file) pair; spotter_file may be None.
            self.enroll_files = [
                (maybe_join(base_dir, spotter_file), maybe_join(base_dir, enroll_audio_file))
                for (spotter_file, enroll_audio_file) in self.enroll_files
            ]

        self.chunk_size = cfg.get('chunk_size', 2456)
        self.chunk_duration = cfg.get('chunk_duration', 0)  # pause between audio chunks, ms
        self.uri_path = cfg.get('uri_path', '/bio')
        self.use_temp_lingware = cfg.get('use_temp_lingware', False)
        self.timeout = cfg.get('timeout', 30)
        self.expected_result = cfg.get('expected_result')

    def process(self, server_port, server_host='localhost'):
        """Run the job against ``server_host:server_port``.

        Resets the collected results (``self.state``), runs the configured
        protocol and logs ``test [<name>] is ok`` when no check failed.
        """
        from voicetech.library.proto_api.yabio_pb2 import Classify as YabioClassify
        from voicetech.library.proto_api.yabio_pb2 import Score as YabioScore
        from voicetech.library.proto_api.yabio_pb2 import YabioContext
        from voicetech.library.proto_api.yabio_pb2 import YabioUser

        logging.debug('process sanity job name={}'.format(self.name))
        self.server_host = server_host
        self.server_port = server_port
        self.state = None

        if self.protocol == 'proto_classify':
            self.run_protobuf(
                method=YabioClassify,
                audio_spotter_file=self.audio_spotter_file,
                audio_file=self.audio_file,
                result_checker=self._check_classification)
        elif self.protocol == 'proto_score':
            # Enrolling: the final response of every enroll recording carries
            # voiceprints in its context; collect them all.
            voiceprints = []
            for audio_spotter_file, audio_file in self.enroll_files:
                result = [None]
                self.run_protobuf(
                    method=YabioScore,
                    result_checker=self._get_score_result_saver(result),
                    audio_spotter_file=audio_spotter_file,
                    audio_file=audio_file,
                    context=YabioContext(group_id=GROUP_ID)
                )
                voiceprints.extend(result[0].context.enrolling)
            logging.debug("Have {} enrolls".format(len(voiceprints)))
            # Scoring: the test audio is scored against a single user that
            # owns all collected voiceprints.
            self.run_protobuf(
                method=YabioScore,
                result_checker=self._check_score,
                audio_spotter_file=self.audio_spotter_file,
                audio_file=self.audio_file,
                context=YabioContext(
                    group_id=GROUP_ID,
                    users=[
                        YabioUser(
                            user_id=USER_ID,
                            voiceprints=voiceprints
                        ),
                    ]
                )
            )
        else:
            self.check_failed('unsupported job protocol={}'.format(self.protocol))
        logging.info('test [{}] is ok'.format(self.name))

    def check_failed(self, msg):
        """Report a failed check, prefixed with the job name, via ``eh.check_failed``."""
        eh.check_failed('FAIL on job <{}>: {}'.format(self.name, msg))

    def check_assert(self, condition, msg=""):
        """Call :meth:`check_failed` with ``msg`` when ``condition`` is falsy."""
        if not condition:
            self.check_failed(msg)

    def run_protobuf(self, method, result_checker, audio_spotter_file, audio_file, context=None):
        """Stream one request to the server and validate every response.

        Steps:

        1. Send an HTTP ``Upgrade: protobuf`` request and expect ``101``.
        2. Send a ``YabioRequest`` and expect ``responseCode == 200``.
        3. If ``audio_spotter_file`` is set, stream it as ``AddData`` chunks
           followed by a ``lastSpotterChunk`` message, then read the
           requested responses.
        4. Stream ``audio_file`` and a final ``lastChunk`` message, then read
           the remaining responses.
        5. Check the send/receive counters, run the final check and make sure
           no extra response arrives.

        Every chunk with an odd running ``sendCount`` is sent with
        ``needResult=True``; each of those (plus the last-spotter and last
        chunk) must produce exactly one ``AddDataResponse``.

        :param method: value for ``YabioRequest.method``.
        :param result_checker: callable ``(result, use_spotter=..., is_final=...)``.
            It gets every ``AddDataResponse`` with ``is_final=False`` and then
            the last one again with ``is_final=True``.
        :param audio_spotter_file: spotter audio path, or None/empty to skip
            the spotter phase.
        :param audio_file: main audio path.
        :param context: optional ``YabioContext`` for the initial request.
        """
        # import voicetech api
        # use local import as recommended https://clubs.at.yandex-team.ru/arcadia/16437
        from voicetech.library.proto_api.yabio_pb2 import YabioRequest
        from voicetech.library.proto_api.yabio_pb2 import YabioResponse
        from voicetech.library.proto_api.yabio_pb2 import AddData
        from voicetech.library.proto_api.yabio_pb2 import AddDataResponse

        # Decide on the spotter phase from the file actually being sent: while
        # enrolling it is the enroll pair's spotter file, not self.audio_spotter_file.
        use_spotter = bool(audio_spotter_file)

        with ProtoTransport(self.server_host, self.server_port, timeout=self.timeout) as t:
            t.verbose = False
            logging.debug('connected, sending upgrade request')
            t.socket.send('GET {} HTTP/1.1\r\nHost: {}\r\nUpgrade: protobuf\r\n\r\n'.format(self.uri_path, self.server_host))
            resp101 = t.socket.recv(1024)
            logging.debug('receive:\n' + resp101)
            if 'HTTP/1.1 101 ' not in resp101:
                self.check_failed('expected "HTTP/1.1 101 " in response, but has={}'.format(resp101))
            logging.debug('upgraded to protobuf')
            init_request = YabioRequest(
                hostName='localhost',
                sessionId='test_session_id',
                uuid='test_uuid',
                group_id=GROUP_ID,
                user_id=USER_ID,
                mime=self.audio_mime,
                context=context,
                method=method,
                classification_tags=self.classification_tags,
                spotter=use_spotter
            )
            logging.debug('sending initial request:\n{}'.format(init_request))
            t.sendProtobuf(init_request)

            resp = t.recvProtobuf(YabioResponse)
            logging.debug('connection is inited, response:\n{}'.format(resp))

            if resp.responseCode != 200:
                self.check_failed('bad responseCode={}'.format(resp.responseCode))
            logging.debug('response hostname={}'.format(resp.hostname))

            sendCount = 0  # sum(AddData)
            needCount = 0  # AddData with needResult=True count or last chunk
            processedCount = 0  # sum(AddDataResponse.messagesCount)
            recvCount = 0  # AddDataResponse
            result = None  # last received AddDataResponse

            start_time = time.time()
            if use_spotter:
                logging.debug('begin transmit audio from file={}'.format(audio_spotter_file))
                for chunk in bytes_from_file(audio_spotter_file, self.chunk_size):
                    sendCount += 1
                    needResult = bool(sendCount % 2)
                    if needResult:
                        needCount += 1
                    logging.debug('about to send {} bytes'.format(len(chunk)))
                    t.sendProtobuf(AddData(
                        audioData=chunk,
                        lastChunk=False,
                        needResult=needResult,
                    ))
                sendCount += 1
                needCount += 1
                logging.debug("Send Last Spotter Chunk")
                t.sendProtobuf(AddData(needResult=True, lastChunk=False, lastSpotterChunk=True))

            while recvCount < needCount:
                logging.debug("Recv spotter result")
                response = t.recvProtobuf(AddDataResponse)
                self.check_assert(
                    response.responseCode == 200,
                    'bad AddDataResponse responseCode={}'.format(response.responseCode))
                recvCount += 1
                processedCount += response.messagesCount
                result = response
                result_checker(result, use_spotter=use_spotter, is_final=False)
                logging.debug(result)

            logging.debug('begin transmit audio from file={}'.format(audio_file))
            for chunk in bytes_from_file(audio_file, self.chunk_size):
                sendCount += 1
                logging.debug('about to send {} bytes'.format(len(chunk)))
                needResult = bool(sendCount % 2)

                t.sendProtobuf(AddData(
                    audioData=chunk,
                    lastChunk=False,
                    needResult=needResult
                ))

                if needResult:
                    needCount += 1

                logging.debug('chunk is sent')
                if self.chunk_duration:
                    logging.debug('suspend next chunk for {} ms'.format(self.chunk_duration))
                    time.sleep(self.chunk_duration / 1000.)

            logging.debug('send last chunk')
            t.sendProtobuf(AddData(lastChunk=True, needResult=True))
            sendCount += 1
            needCount += 1

            while recvCount < needCount:
                logging.debug("Recv result")
                response = t.recvProtobuf(AddDataResponse)
                self.check_assert(
                    response.responseCode == 200,
                    'bad AddDataResponse responseCode={}'.format(response.responseCode))
                logging.debug('response messageCount={}'.format(response.messagesCount))
                recvCount += 1
                processedCount += response.messagesCount
                result = response
                result_checker(result, use_spotter=use_spotter, is_final=False)
                logging.debug(result)

            logging.debug('send chunks={} processed chunks={}'.format(sendCount, processedCount))
            # check_assert instead of `assert`: must not vanish under `python -O`.
            self.check_assert(
                sendCount == processedCount,
                'sent {} chunks, but server reported {} processed'.format(sendCount, processedCount))
            self.check_assert(
                needCount == recvCount,
                'expected {} responses, but received {}'.format(needCount, recvCount))
            logging.info("duration={}".format(time.time() - start_time))
            result_checker(result, use_spotter=use_spotter, is_final=True)

            # No more responses are expected, so this read is expected to raise
            # (presumably a timeout, the transport is created with `timeout`).
            # try/else keeps the "extra response" error outside the except.
            try:
                t.recvProtobuf(AddDataResponse)
            except Exception:
                logging.debug('catch expected error when try read extra AddDataResponse')
            else:
                raise Exception('got extra/unexpected AddDataResponse')

    def _check_basic(self, result, use_spotter=False, is_final=False):
        """Record ``result`` in ``self.state``; on the final call require at
        least two recorded results.

        NOTE(review): run_protobuf passes its last response twice (is_final=False
        in the receive loop, then is_final=True), so this requirement is always
        met by that single response — confirm intent.
        """
        if self.state is None:
            self.state = []
        self.state.append(result)
        if not is_final:
            return

        self.check_assert(len(self.state) >= 2, "There should be at least one partial and one final result")

    def _check_classification(self, result, use_spotter=False, is_final=False):
        """Result checker for ``proto_classify`` jobs.

        Records every result; on the final call validates the last two recorded
        results: their ``classification`` and ``classificationResults`` must
        cover exactly ``classification_tags`` without duplicate classnames and,
        unless ``expected_result['skip']`` is set, the classname reported in
        ``classificationResults`` must equal ``expected_result[tag]`` (tags
        whose expectation is None are not compared).
        """
        self._check_basic(result, use_spotter=use_spotter, is_final=is_final)
        if not is_final:
            return

        requested_tags = set(self.classification_tags)
        expected = self.expected_result or {}

        def check_result(some_result):
            self.check_assert(some_result.classification, "No classification result for classification request")

            # One classname per tag; compared against the expectation below.
            tag_to_classname = {r.tag: r.classname for r in some_result.classificationResults}
            self.check_assert(
                set(tag_to_classname) == requested_tags,
                'classificationResults tags {} differ from requested {}'.format(
                    sorted(tag_to_classname), sorted(requested_tags)))

            # check that all classification tags are present, all present tags
            # are requested and no classname repeats within a tag
            tag_to_classnames = defaultdict(set)
            for item in some_result.classification:
                self.check_assert(
                    item.classname not in tag_to_classnames[item.tag],
                    "Duplicate classname {}".format(item.classname)
                )
                tag_to_classnames[item.tag].add(item.classname)

            self.check_assert(
                set(tag_to_classnames) == requested_tags,
                'classification tags {} differ from requested {}'.format(
                    sorted(tag_to_classnames), sorted(requested_tags)))

            # skip is for, e. g., broken files
            if expected.get('skip'):
                return
            for tag in tag_to_classnames:
                self.check_assert(tag in expected, 'no expected result for tag {}'.format(tag))
                if expected[tag] is None:
                    continue
                self.check_assert(
                    tag_to_classname[tag] == expected[tag],
                    'tag {}: got class {}, expected {}'.format(tag, tag_to_classname[tag], expected[tag]))

        # checking final
        check_result(self.state[-1])
        # last partial should also be good
        # NOTE(review): state[-2] is the same object as state[-1] (see _check_basic).
        check_result(self.state[-2])

    def _check_context(self, result, use_spotter=False, is_final=False):
        """Check that ``result.context`` is for GROUP_ID and that its first
        enrolling entry has compatibility_tag, format, source and a non-empty
        voiceprint."""
        context = result.context
        self.check_assert(context, 'response has no context')
        self.check_assert(
            context.group_id == GROUP_ID,
            'unexpected context group_id={}'.format(context.group_id))
        self.check_assert(len(context.enrolling) > 0, 'response context has no enrolling')

        enrolling = context.enrolling[0]
        self.check_assert(enrolling.compatibility_tag, 'enrolling has no compatibility_tag')
        self.check_assert(enrolling.format, 'enrolling has no format')
        self.check_assert(enrolling.source, 'enrolling has no source')
        self.check_assert(len(enrolling.voiceprint) > 0, 'enrolling has an empty voiceprint')

    def _check_score(self, result, use_spotter=False, is_final=False):
        """Result checker for the scoring request of ``proto_score`` jobs.

        Every result: context is valid and ``supported_tags`` are unique and
        number ``(1 + use_temp_lingware) * 2``. Final call: each
        ``scores_with_mode`` entry of the last two results has exactly one
        score in [0, 1] for USER_ID, and unless ``expected_result['skip']``
        is set, ``score < 0.5`` must equal ``expected_result['is_guest']``.
        """
        self._check_basic(result, use_spotter=use_spotter, is_final=is_final)
        self._check_context(result, use_spotter=use_spotter, is_final=is_final)

        supported_tags = list(result.supported_tags)
        unique_tags = set(supported_tags)
        self.check_assert(
            len(unique_tags) == len(supported_tags),
            'duplicate supported_tags: {}'.format(supported_tags))

        # 2 because one spotter and one request
        # NOTE(review): the count does not depend on use_spotter — confirm.
        expected_tags_count = (1 + self.use_temp_lingware) * 2
        self.check_assert(
            len(unique_tags) == expected_tags_count,
            'expected {} supported_tags, got {}'.format(expected_tags_count, sorted(unique_tags)))

        if not is_final:
            return

        expected = self.expected_result or {}

        def check_result(some_result):
            for mode_score in some_result.scores_with_mode:
                scores = list(mode_score.scores)
                # one user => one score
                self.check_assert(len(scores) == 1, 'expected one score, got {}'.format(len(scores)))

                # sum for possible future with many users
                scores_sum = sum(s.score for s in scores)

                self.check_assert(0.0 <= scores_sum <= 1.0, 'score {} is out of [0, 1]'.format(scores_sum))
                self.check_assert(
                    all(s.user_id == USER_ID for s in scores),
                    'unexpected user_id in scores')

                # skip is for, e. g., broken files
                if not expected.get('skip'):
                    self.check_assert('is_guest' in expected, 'no "is_guest" in expected_result')
                    is_guest = (scores_sum < 0.5)
                    self.check_assert(
                        is_guest == expected['is_guest'],
                        'score {}: is_guest={}, expected {}'.format(scores_sum, is_guest, expected['is_guest']))

        # checking final
        check_result(self.state[-1])
        # last partial should also be good
        # NOTE(review): state[-2] is the same object as state[-1] (see _check_basic).
        check_result(self.state[-2])

    def _get_score_result_saver(self, destination):
        """Return a result checker for enrolling requests.

        The checker validates each response's context and stores the final
        response in ``destination[0]``, which must be a one-element list
        holding None.
        """
        assert len(destination) == 1
        assert destination[0] is None

        def score_result_saver(result, use_spotter=False, is_final=False):
            self._check_context(result, use_spotter=use_spotter, is_final=is_final)
            if is_final:
                destination[0] = result

        return score_result_saver
