# -*- coding: utf-8 -*-

import httplib
import json
import logging
import os
import time
import urllib2
import xml.etree.ElementTree as etree
from contextlib import closing

from sandbox.projects.common import error_handlers as eh
from .proto_transport import ProtoTransport


def bytes_from_file(filename, chunksize=8192):
    """Yield the binary contents of *filename* piece by piece.

    :param filename: path of the file to read (opened in ``"rb"`` mode).
    :param chunksize: maximum number of bytes per yielded piece.
    :returns: a generator of non-empty byte strings; iteration stops at the
        first empty read (end of file).
    """
    with open(filename, "rb") as stream:
        # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
        for piece in iter(lambda: stream.read(chunksize), b""):
            yield piece


class SanityRecognitionJob:
    """One sanity check executed against a running ASR server.

    The ``protocol`` option selects what is exercised:

    * ``http``    -- chunked POST of ``audio_file`` to ``uri_path``; the XML answer
      must have ``<Status>OK</Status>`` and (optionally) the expected words;
    * ``proto``   -- HTTP upgrade to protobuf, Yaldi init request, streamed audio,
      then validation of the recognition results;
    * ``ping``    -- GET ``/ping`` must answer 200;
    * ``info``    -- GET ``/info`` must return ``Info/Result`` XML with the
      ``Hostname``, ``Version``, ``Lingware`` and ``Config`` tags;
    * ``unistat`` -- GET ``/unistat`` must be a JSON list of ``[name, value]``
      pairs containing ``http_code_2xx_summ``.

    Failures are reported through ``eh.check_failed``. NOTE(review): the code
    assumes that call does not return (presumably it raises).

    Expectation strings (``expect_result``, ``expect_normalized``,
    ``expect_partial_normalized``) use ``|||`` to separate utterances and ``|``
    to separate the acceptable alternatives for a single utterance.
    """

    def __init__(self, cfg, base_dir):
        """Read the job options.

        :param cfg: mapping with the job options (read via ``cfg.get``).
        :param base_dir: directory that relative ``audio_file`` and
            ``audio_spotter_file`` paths are joined with; only applied for the
            ``http`` and ``proto`` protocols, which also require ``audio_file``
            and ``audio_mime``.
        """
        self.name = cfg.get('name', 'unknown')
        self.protocol = cfg.get('protocol', 'http')
        self.audio_spotter_file = cfg.get('audio_spotter_file')
        self.audio_file = cfg.get('audio_file')
        self.audio_mime = cfg.get('audio_mime')
        if self.protocol in ('http', 'proto'):
            if not self.audio_file:
                eh.check_failed('sanity job MUST have "audio_file" option')
            if not self.audio_mime:
                eh.check_failed('sanity job MUST have "audio_mime" option')
            self.audio_file = str(os.path.join(base_dir, self.audio_file))
            if self.audio_spotter_file:
                self.audio_spotter_file = str(os.path.join(base_dir, self.audio_spotter_file))
        self.antimat = cfg.get('antimat', True)
        self.enable_e2e_eou = cfg.get('enable_e2e_eou')
        self.partial_update_period = cfg.get('partial_update_period')
        self.chunk_size = cfg.get('chunk_size', 2000)
        # delay between audio chunks in milliseconds (0 = send as fast as possible)
        self.chunk_duration = cfg.get('chunk_duration', 0)
        self.spotter_back = cfg.get('spotter_back')
        self.request_front = cfg.get('request_front')
        self.spotter_validation = cfg.get('spotter_validation')
        self.spotter_phrase = cfg.get('spotter_phrase')
        self.punctuation = cfg.get('punctuation', False)
        self.topic = cfg.get('topic', 'topic')
        self.expect_result = cfg.get('expect_result')
        self.uri_path = cfg.get('uri_path', '/asr')
        self.expect_error = cfg.get('expect_error')
        self.expect_normalized = cfg.get('expect_normalized')
        self.expect_partial_normalized = cfg.get('expect_partial_normalized')
        # (min, max) range the first end-of-utterance duration_processed_audio must fall in
        self.expect_eou_between = cfg.get('expect_eou_between')
        self.use_trash_talk_classifier = cfg.get('use_trash_talk_classifier')
        self.timeout = cfg.get('timeout')
        self.context = cfg.get('context')
        self.language = cfg.get('language', 'ru')
        self.revnorm_language = cfg.get('revnorm_language', 'ru-RU')
        self.force_process_each_chunk = cfg.get('force_process_each_chunk', False)
        self.enable_degradation = cfg.get('enable_degradation', False)
        self.always_fill_response = cfg.get('always_fill_response', False)
        self.expect_raw_recv_count = cfg.get('expect_raw_recv_count', None)

        # sent JSON-encoded in the protobuf init request
        self.experiments = {
            "ASR": {
                "flags": cfg.get('ab_flags', []),
                "boxes": []
            }
        }

    def process(self, server_port, server_host='localhost'):
        """Run the job against ``server_host:server_port``.

        Dispatches on ``self.protocol``; an unknown protocol is reported via
        ``eh.check_failed``.
        """
        logging.debug('process sanity job name={}'.format(self.name))
        self.server_host = server_host
        self.server_port = server_port
        if self.protocol == 'http':
            self.run_http_recognition()
        elif self.protocol == 'proto':
            self.run_protobuf_recognition()
        elif self.protocol == 'ping':
            self.run_ping()
        elif self.protocol == 'info':
            self.run_info()
        elif self.protocol == 'unistat':
            self.run_unistat()
        else:
            eh.check_failed('unsupported job protocol={}'.format(self.protocol))
        logging.info('test [{}] is ok'.format(self.name))

    def check_failed(self, msg):
        """Report *msg* as a failure of this job (prefixed with the job name)."""
        eh.check_failed('FAIL on job <{}>: {}'.format(self.name, msg))

    @staticmethod
    def _split_expectations(value):
        """Split an expectation option into per-utterance UTF-8 strings.

        Returns ``[]`` when the option is not set (``None``).
        """
        if value is None:
            return []
        return value.encode('utf-8').split('|||')

    def _http_get(self, path):
        """GET *path* (e.g. ``'/ping'``) from the server and return the body.

        Fails the job on a non-200 status or an ``urllib2.HTTPError``. Other
        errors from ``urlopen`` (e.g. ``URLError`` on connection refused)
        propagate unchanged.
        """
        url = 'http://{}:{}{}'.format(self.server_host, self.server_port, path)
        body = None
        try:
            with closing(urllib2.urlopen(url, timeout=15)) as resp:
                if resp.getcode() != 200:
                    self.check_failed('server return not 200 Ok on {} request {}'.format(path, resp.getcode()))
                body = resp.read()
                logging.debug('{} response: {}'.format(path.lstrip('/'), body))
        except urllib2.HTTPError as e:
            self.check_failed('HTTPError in getting {} code {}'.format(url, e))
        return body

    def run_ping(self):
        """Check that ``/ping`` answers 200."""
        self._http_get('/ping')

    def run_info(self):
        """Check that ``/info`` returns ``<Info><Result>`` with the required child tags."""
        result = self._http_get('/info')

        # etree.fromstring raises ParseError on malformed XML
        info_xml = etree.fromstring(result)
        if info_xml.tag != 'Info':
            self.check_failed('bad info response on /info request (not Info tag at root):\n{}'.format(result))

        result_xml = info_xml.find('Result')
        if result_xml is None:
            self.check_failed('bad info response on /info request (not found Result tag):\n{}'.format(result))

        for tag in ('Hostname', 'Version', 'Lingware', 'Config'):
            if result_xml.find(tag) is None:
                self.check_failed('bad info response on /info request (not found {} tag):\n{}'.format(
                    tag, result))

    def run_unistat(self):
        """Check that ``/unistat`` contains the ``http_code_2xx_summ`` counter."""
        result = self._http_get('/unistat')

        # unistat is a JSON list of [counter_name, value] pairs
        unistat = json.loads(result)
        http_2xx = None
        for cnt, val in unistat:
            if cnt == 'http_code_2xx_summ':
                http_2xx = val
        if http_2xx is None:
            self.check_failed('unistat result not contain <http_code_2xx_summ> counter')

    def run_http_recognition(self):
        """POST ``audio_file`` with chunked transfer encoding and validate the answer.

        The answer must have status 200 and contain ``<Status>OK</Status>``. When
        ``expect_result`` is set, only its first ``|||`` group is checked: the
        words of the first ``Transcript`` must equal one of its ``|`` alternatives.
        """
        con = httplib.HTTPConnection(
            '{}:{}'.format(self.server_host, self.server_port),
            timeout=self.timeout if self.timeout else 3600,
        )
        try:
            with open(self.audio_file, 'rb') as f:
                # NOTE(review): putrequest() already adds a Host header by default
                # (skip_host=False), so the explicit Host below is sent as a second
                # one; kept as-is because the server may depend on it.
                con.putrequest('POST', '{}?uuid=1234'.format(self.uri_path))
                con.putheader('X-Yaldi-Mode', self.topic + '.' + self.language)
                con.putheader('X-Yaldi-App', 'app')
                con.putheader('Host', 'localhost')
                con.putheader('Content-Type', self.audio_mime)
                con.putheader('X-Yaldi-RequestId', '123')
                con.putheader('Transfer-Encoding', 'chunked')
                con.endheaders()

                logging.debug('sending chunked http recognition request')
                chunk = f.read(self.chunk_size)
                while chunk:
                    # chunked framing: "<hex length>\r\n<data>\r\n"
                    con.send('{:x}\r\n'.format(len(chunk)))
                    con.send(chunk + '\r\n')
                    chunk = f.read(self.chunk_size)
                # a zero-length chunk terminates the body
                con.send('0\r\n\r\n')

            resp = con.getresponse()
            if not resp:
                self.check_failed('has not response from asr-server')
            answer = resp.read()
        finally:
            # previously the connection (and its socket) was never closed
            con.close()

        logging.debug('response (status={}):\n{}'.format(resp.status, answer))

        if resp.status != 200:
            self.check_failed('bad response status={}'.format(resp.status))

        if '<Status>OK</Status>' not in answer:
            logging.error(answer)
            self.check_failed('not found OK status in answer')

        if self.expect_result is None:
            return

        expect_results = self._split_expectations(self.expect_result)
        # etree.fromstring raises ParseError on malformed XML (it never returns None)
        recog = etree.fromstring(answer)
        if recog.tag != 'Recognition':
            logging.error(answer)
            logging.error('recog={}'.format(recog.tag))
            self.check_failed('bad root tag in response (expect Recognition)')
        trs = recog.find('Transcriptions')
        if trs is None:
            self.check_failed('not found Transcriptions tag in answer')
        tr = trs.find('Transcript')
        if tr is None:
            self.check_failed('not found Transcript tag in answer')
        # an empty <Word/> has text None; treat it as an empty word instead of crashing
        words = [(word.text or u'').encode('utf-8') for word in tr if word.tag == 'Word']
        utt = ' '.join(words)
        logging.debug('recognition_result="{}"'.format(utt))
        expect_versions = expect_results[0].split('|')
        if utt not in expect_versions:
            self.check_failed('unexpected recogntion result="{}", expect one of="{}"'.format(
                utt,
                ' | '.join(expect_versions),
            ))

    def run_protobuf_recognition(self):
        """Stream audio over the Yaldi protobuf protocol and validate the results.

        The connection is upgraded to protobuf, a ``YaldiInitRequest`` is sent,
        then the optional ``audio_spotter_file`` and the ``audio_file`` are sent in
        ``chunk_size`` pieces followed by a ``lastChunk`` marker. Responses are read
        until the summed ``messagesCount`` covers every sent message (or, for
        ``request_front`` jobs, until the first end-of-utterance response).

        If the server answers with a non-OK code equal to ``expect_error``, the
        job succeeds immediately without validating results.
        """
        # import voicetech api
        # use local import as recommended https://clubs.at.yandex-team.ru/arcadia/16437
        from voicetech.library.proto_api.voiceproxy_pb2 import AdvancedASROptions
        from voicetech.library.proto_api.yaldi_pb2 import AddData as YaldiAddData
        from voicetech.library.proto_api.yaldi_pb2 import AddDataResponse as YaldiAddDataResponse
        from voicetech.library.proto_api.yaldi_pb2 import InitRequest as YaldiInitRequest
        from voicetech.library.proto_api.yaldi_pb2 import InitResponse as YaldiInitResponse
        from voicetech.library.proto_api.yaldi_pb2 import NormalizerOptions
        from voicetech.library.proto_api.yaldi_pb2 import OK as YaldiOK
        from voicetech.library.proto_api.yaldi_pb2 import Context as YaldiContext

        with ProtoTransport(self.server_host, self.server_port, timeout=self.timeout) as t:
            t.verbose = False
            logging.debug('connected, sending upgrade request')
            t.socket.send('GET {} HTTP/1.1\r\nHost: {}\r\nUpgrade: protobuf\r\n\r\n'.format(self.uri_path, self.server_host))
            resp101 = t.socket.recv(1024)
            logging.debug('receive:\n' + resp101)
            if 'HTTP/1.1 101 ' not in resp101:
                self.check_failed('expected "HTTP/1.1 101 " in response, but has={}'.format(resp101))
            logging.debug('upgraded to protobuf')
            norm_banlist = []
            if not self.antimat:
                # antimat off: put the 'reverse_conversion.profanity' stage on the normalizer banlist
                norm_banlist.append('reverse_conversion.profanity')
            normalizer_options = NormalizerOptions(
                name='revnorm',
                lang=self.revnorm_language,
                banlist=norm_banlist,
            )
            context = None
            if self.context:
                # protobuf keyword arguments must be str keys
                context = [YaldiContext(**{str(k): v for k, v in c.items()}) for c in self.context]
            init_request = YaldiInitRequest(
                hostName='localhost',
                requestId='777',
                uuid='1234',
                device='unknown',
                coords='0, 0',
                topic=self.topic,
                lang=self.language,
                sampleRate='16000',
                punctuation=self.punctuation,
                normalizer_options=normalizer_options,
                experiments=json.dumps(self.experiments),
                advanced_options=AdvancedASROptions(
                    mime=self.audio_mime,
                    biometry_group='azaza',
                    utterance_silence=10,
                    chunk_process_limit=1000,
                    biometry='gender,age,language,group,emotion-0.0.1',
                    partial_results=True,
                    partial_update_period=self.partial_update_period,
                    spotter_back=self.spotter_back,
                    request_front=self.request_front,
                    spotter_validation=self.spotter_validation,
                    spotter_phrase=self.spotter_phrase,
                    enable_e2e_eou=self.enable_e2e_eou,
                    force_process_each_chunk=self.force_process_each_chunk,
                ),
                context=context,
            )
            if self.use_trash_talk_classifier:
                init_request.advanced_options.use_trash_talk_classifier = True

            if self.enable_degradation:
                init_request.advanced_options.degradation_mode = AdvancedASROptions.EDegradationMode.Enable

            init_request.advanced_options.always_fill_response = self.always_fill_response

            logging.debug('sending initial request:\n{}'.format(init_request))
            t.sendProtobuf(init_request)

            resp = t.recvProtobuf(YaldiInitResponse)
            logging.debug('connection is inited, response:\n{}'.format(resp))

            if resp.responseCode != 200:
                self.check_failed('bad responseCode={}'.format(resp.responseCode))
            logging.debug('response topic={}'.format(resp.topic))
            logging.debug('response version={}'.format(resp.topic_version))

            # Starts at 1 so the final lastChunk=True message (sent outside the
            # loops below) is counted too; the read loop waits until the summed
            # messagesCount reaches this value.
            sendCount = 1
            recvCount = 0
            rawRecvCount = 0  # number of AddDataResponse messages actually read
            resultsnow = False  # interleaved reading while sending; currently disabled
            logging.debug('begin transmit audio from file={}'.format(self.audio_file))
            start_time = time.time()
            if self.audio_spotter_file:
                for chunk in bytes_from_file(self.audio_spotter_file, self.chunk_size):
                    sendCount += 1
                    logging.debug('about to send {} bytes'.format(len(chunk)))

                    t.sendProtobuf(YaldiAddData(
                        audioData=chunk,
                        lastChunk=False,
                    ))

            for chunk in bytes_from_file(self.audio_file, self.chunk_size):
                sendCount += 1
                logging.debug('about to send {} bytes'.format(len(chunk)))

                t.sendProtobuf(YaldiAddData(
                    audioData=chunk,
                    lastChunk=False,
                ))

                logging.debug('chunk is sent')
                if self.chunk_duration:
                    logging.debug('suspend next chunk for {} ms'.format(self.chunk_duration))
                    time.sleep(self.chunk_duration / 1000.)

                if resultsnow:
                    response = t.recvProtobuf(YaldiAddDataResponse)
                    recvCount += response.messagesCount
                    # was `logging(...)`, which calls the module object and raises TypeError
                    logging.debug('receive response: {}'.format(response))

            logging.debug('send last chunk')
            t.sendProtobuf(YaldiAddData(lastChunk=True))
            # NOTE(review): despite the name, every normalized partial is appended
            # here, and they are later matched to utterances by position; confirm
            # whether only the first partial per utterance was intended.
            first_partial_results = []
            end_of_utt_results = []

            partial_result = None
            while recvCount < sendCount:
                response = t.recvProtobuf(YaldiAddDataResponse)
                rawRecvCount += 1
                logging.debug('response:\n{}'.format(response))
                if response.responseCode != YaldiOK:
                    if self.expect_error is not None and response.responseCode == self.expect_error:
                        logging.info('got expected error={}'.format(self.expect_error))
                        return
                    self.check_failed('bad[2] responseCode={}'.format(response.responseCode))
                recvCount += response.messagesCount
                result = response
                if result.endOfUtt and result.recognition:
                    end_of_utt_results.append(result)
                    partial_result = None
                elif not result.endOfUtt and result.recognition and result.recognition[0].HasField('normalized'):
                    partial_result = result.recognition[0].normalized
                    first_partial_results.append(partial_result)
                if result.recognition:
                    logging.debug('recognition result="{}"'.format(
                        ' '.join(word.value for word in result.recognition[0].words)))
                if result.metainfo:
                    logging.debug('META: minBeam={} maxBeam={} lang={} topic={} version={} load_timestamp={}'.format(
                        result.metainfo.minBeam, result.metainfo.maxBeam, result.metainfo.lang,
                        result.metainfo.topic, result.metainfo.version, result.metainfo.load_timestamp))
                for r in result.bioResult:
                    logging.debug('BIO: classname={} confidence={}'.format(r.classname, r.confidence))
                if self.request_front and result.endOfUtt:
                    # what happens to rawRecvCount?
                    # for now -- nothing important, we do not use always_fill_response
                    # (and, therefore, do not check rawRecvCount) on spotter requests
                    break

            logging.info('total recognition results={} duration={:.6f} sec'.format(
                len(end_of_utt_results), time.time() - start_time))

            self._check_protobuf_results(end_of_utt_results, first_partial_results, rawRecvCount)

    def _check_protobuf_results(self, end_of_utt_results, first_partial_results, raw_recv_count):
        """Validate the collected protobuf responses against the job expectations.

        :param end_of_utt_results: ``AddDataResponse`` messages with ``endOfUtt``
            set and a non-empty ``recognition``, in arrival order.
        :param first_partial_results: normalized partial results, in arrival order;
            entry ``i`` is compared with utterance ``i``.
        :param raw_recv_count: number of ``AddDataResponse`` messages read.
        """
        expect_results = self._split_expectations(self.expect_result)
        expect_normalized_results = self._split_expectations(self.expect_normalized)
        expect_partial_results = self._split_expectations(self.expect_partial_normalized)

        if self.expect_raw_recv_count is not None:
            if self.expect_raw_recv_count != raw_recv_count:
                self.check_failed('expected {} raw recv, got {}'.format(self.expect_raw_recv_count, raw_recv_count))

        # A trailing '...' in expect_partial_normalized means "do not check the
        # partials from this position onward".
        partial_wildcard = bool(expect_partial_results) and expect_partial_results[-1] == '...'
        logging.debug('expect partial results: {}'.format(expect_partial_results))

        first_eou_audio_duration = None
        for i, result in enumerate(end_of_utt_results):
            recognition_result = ' '.join(word.value for word in result.recognition[0].words)
            partial_norm_result = first_partial_results[i] if i < len(first_partial_results) else ''
            if first_eou_audio_duration is None:
                first_eou_audio_duration = result.duration_processed_audio

            logging.info('recognition result[{}]="{}"'.format(i, recognition_result))
            normalized_recognition_result = None
            if result.recognition[0].HasField('normalized'):
                normalized_recognition_result = result.recognition[0].normalized
                logging.info('normalized recognition result[{}]="{}"'.format(i, normalized_recognition_result))
            logging.info('partial normalized recognition result[{}]="{}"'.format(i, partial_norm_result))
            for r in result.bioResult:
                logging.info('BIO recognition result: tag={} classname={} confidence={}'.format(r.tag, r.classname, r.confidence))

            if not self.spotter_validation:
                degradation_enabled = json.loads(result.core_debug)["InDegradationMode"]
                if degradation_enabled != self.enable_degradation:
                    self.check_failed('Degradation check. Expected {} got {}'.format(self.enable_degradation, degradation_enabled))

            if expect_results:
                if i >= len(expect_results):
                    self.check_failed('recognize unexpected (extra) utterance: {}'.format(recognition_result))
                expect_results_ = expect_results[i].split('|')
                logging.debug('for utterance[{}] expect one of recognition results:\n{}'.format(
                    i, '\n'.join(expect_results_)))
                if recognition_result not in expect_results_:
                    self.check_failed('recognition result not in expect_results')
                logging.debug('utterance[{}] has expected recognition result'.format(i))

            if expect_normalized_results:
                if i >= len(expect_normalized_results):
                    self.check_failed('recognize unexpected (extra) utterance (norm): {}'.format(normalized_recognition_result))
                expect_normalized_results_ = expect_normalized_results[i].split('|')
                logging.debug('for utterance[{}] expect one of normalized recognition results:\n{}'.format(
                    i, '\n'.join(expect_normalized_results_)))
                if normalized_recognition_result not in expect_normalized_results_:
                    self.check_failed('recognition normalized result not in expect_normalized_results')
                logging.debug('utterance[{}] has expected normalized recognition result'.format(i))

            if expect_partial_results:
                if partial_wildcard:
                    # the '...' entry itself must not be compared literally
                    skip_check = i >= len(expect_partial_results) - 1
                else:
                    skip_check = False
                    # bound by the partial expectations (was wrongly expect_normalized_results)
                    if i >= len(expect_partial_results):
                        self.check_failed('recognize partial unexpected (extra) utterance (norm): {}'.format(normalized_recognition_result))
                if not skip_check:
                    expect_partial_results_ = expect_partial_results[i].split('|')
                    logging.debug('for partial utterance[{}] expect one of normalized recognition results:\n{}'.format(
                        i, '\n'.join(expect_partial_results_)))
                    if partial_norm_result not in expect_partial_results_:
                        self.check_failed('recognition partial normalized result={} not in expect_partial_results={}'.format(
                            partial_norm_result, expect_partial_results_))
                    logging.debug('partial utterance[{}] has expected normalized recognition result'.format(i))

        if not end_of_utt_results and (expect_results or expect_normalized_results or expect_partial_results):
            self.check_failed('not found recognition results for validation')

        if self.expect_eou_between:
            logging.debug('EOU is at {}'.format(first_eou_audio_duration))
            eou_min, eou_max = self.expect_eou_between
            if first_eou_audio_duration is None:
                self.check_failed('EOU is expected in range [{}, {}], but no end-of-utterance result was received'.format(
                    eou_min, eou_max))
            elif not eou_min <= first_eou_audio_duration <= eou_max:
                self.check_failed("EOU {} is not in range [{}, {}]".format(first_eou_audio_duration, eou_min, eou_max))