# -*- coding: utf-8 -*-

import base64
import hashlib
import urllib
import urllib2
import logging
import threading
import Queue
import time
import json
import urlparse

from sandbox.sandboxsdk import errors
from sandbox.sandboxsdk import parameters
from sandbox.sandboxsdk import task

from sandbox.projects import resource_types
from sandbox.projects.common import utils


class SaasRequest(object):
    def __init__(self, host_and_port, service_id, index_prefix, timeout, retry_count, text, storage_keys, region, reqid):
        self.host_and_port = host_and_port
        self.service_id = service_id
        self.index_prefix = index_prefix
        self.timeout = timeout
        self.retry_count = retry_count
        self.text = text
        self.storage_keys = storage_keys
        self.region = region
        self.reqid = reqid

    def calc_key(self):
        subkeys = []
        for key in self.storage_keys:
            if key == '0':
                m = hashlib.md5()
                m.update(self.text + '\t0')
                subkeys.append(m.hexdigest())
            else:
                m = hashlib.md5()
                m.update(self.text)
                subkeys.append(m.hexdigest() + '_' + key.encode('ascii'))
        return ','.join(subkeys)

    def calc_url(self):
        self.key = self.calc_key()
        params = (
            ('text', self.key),
            ('g', '0..10.1.-1.0.0.-1.rlv.0..0.0'),
            ('how', 'rlv'),
            ('numdoc', '10'),
            ('kps', self.index_prefix),
            ('service', self.service_id),
            ('timeout', str(self.timeout * 1000)),
            ('ms', 'proto'),
            ('relev', 'attr_limit=999999999'),
        )
        return 'http://' + self.host_and_port + '/?' + '&'.join(param + '=' + urllib.quote(value) for param, value in params)

    def calc_timeout(self):
        return (self.timeout+100) / 1000.0


def worker_proc(tasks_for_workers, completed_tasks):
    while True:
        saas_request = tasks_for_workers.get()
        if saas_request is None:
            break
        num_retries = 0
        url = saas_request.calc_url()
        while True:
            try:
                request = urllib2.urlopen(saas_request.calc_url(), timeout=saas_request.calc_timeout())
                saas_request.answer = request.read()
                break
            except IOError as e:
                if num_retries >= saas_request.retry_count:
                    logging.error('error processing request "{}" with region={}, URL={}: {}'.format(saas_request.text, saas_request.region, url, e))
                    saas_request.answer = None
                    break
                num_retries += 1
                logging.debug('error processing request "{}" with region={}, URL={}: {}, retrying #{}'.format(saas_request.text, saas_request.region, url, e, num_retries))
                time.sleep(1)

        completed_tasks.put(saas_request)


def write_one_result(output_counters, output_file, saas_request):
    failed = True
    if saas_request.answer:
        output_file.write('\t'.join((saas_request.text, str(saas_request.region), saas_request.key, base64.b64encode(saas_request.answer))) + '\n')
        failed = False
    else:
        output_file.write('\t'.join((saas_request.text, str(saas_request.region), saas_request.key, '[error]')) + '\n')
    with output_counters['lock']:
        output_counters['num_requests'] += 1
        if failed:
            output_counters['num_errors'] += 1


def writer_proc(output_counters, output_filename, completed_tasks):
    """Writer thread body: write completed requests to ``output_filename`` in reqid order.

    Workers finish requests out of order. Results that arrive ahead of the next
    expected ``reqid`` are buffered until the gap is filled. A ``None``
    sentinel ends the loop.

    If results are still buffered at that point (some earlier reqid never
    arrived), they are written in reqid order with a warning instead of being
    silently dropped.

    Any exception is logged and ends the thread; it is not re-raised.
    """
    try:
        with open(output_filename, 'w') as output_file:
            next_reqid = 0
            pending = {}
            while True:
                saas_request = completed_tasks.get()
                if saas_request is None:
                    break
                pending[saas_request.reqid] = saas_request
                # Flush the contiguous run starting at the next expected reqid.
                while next_reqid in pending:
                    write_one_result(output_counters, output_file, pending.pop(next_reqid))
                    next_reqid += 1
            if pending:
                logging.warning('{} results were still waiting for missing reqid {}; writing them out of sequence'.format(len(pending), next_reqid))
                for reqid in sorted(pending):
                    write_one_result(output_counters, output_file, pending[reqid])
    except Exception:
        logging.exception('writer thread failed while writing {}'.format(output_filename))


class LbGetSaasResponses(task.SandboxTask):
    """
        Collects raw responses of the lingboosting database stored in SaaS
        for a list of queries.
        SEARCH-2255
    """

    type = 'LB_GET_SAAS_RESPONSES'

    # Country region ids mapped to the relev_locale names that select an entry
    # in RegionDispatchConfig. USERS_QUERIES regions are lifted to one of
    # these via geoa.c2p.
    _COUNTRY_LOCALES = {225: 'ru', 187: 'ua', 159: 'kz', 149: 'by', 983: 'tr'}

    class InputQueries(parameters.ResourceSelector):
        name = 'input_queries'
        description = 'Input queries'
        resource_type = [
            resource_types.USERS_QUERIES,
            resource_types.PLAIN_TEXT_QUERIES,
            "WEB_MIDDLESEARCH_PLAIN_TEXT_QUERIES",
            resource_types.IMAGES_MIDDLESEARCH_PLAIN_TEXT_REQUESTS,
            resource_types.VIDEO_MIDDLESEARCH_PLAIN_TEXT_REQUESTS,
        ]
        required = True

    class NumQueriesParam(parameters.SandboxIntegerParameter):
        name = 'num_queries'
        description = 'Number of queries to use (0 = entire file)'
        default_value = 0

    class UseNoregRequests(parameters.SandboxBoolParameter):
        name = 'use_noreg_requests'
        description = 'Collect responses from old-style (nonregional) requests'
        default_value = True

    class UseRegionRequests(parameters.SandboxBoolParameter):
        name = 'use_region_requests'
        description = 'Collect responses from new-style requests'
        sub_fields = {'true': ['region_dispatch_config', 'geoa_c2p']}
        default_value = True

    class RegionDispatchConfig(parameters.SandboxStringParameter):
        name = 'region_dispatch_config'
        description = 'SaasDispatchConfig as json'
        default_value = '{"ru":["0","ru"],"ua":["0","ua"],"by":["0","by"],"kz":["0","kz"],"tr":["0","tr"]}'

    class GeoaC2P(parameters.ResourceSelector):
        name = 'geoa_c2p'
        description = 'geoa.c2p for lifting USERS_QUERIES to countries'
        resource_type = resource_types.GEOA_C2P

    class LingboostSaasHost(parameters.SandboxStringParameter):
        name = 'lingboost_saas_host'
        description = 'Host:port of lingboosting SaaS storage'
        default_value = 'saas-searchproxy-prestable.yandex.net:17000'

    class SaaSServiceId(parameters.SandboxStringParameter):
        name = 'saas_service_id'
        description = 'SaaS service ID'
        default_value = 'lingv_boosting'

    class SaaSIndexPrefix(parameters.SandboxStringParameter):
        name = 'saas_index_prefix'
        description = 'SaaS index prefix'
        default_value = '0'

    class TimeoutParam(parameters.SandboxIntegerParameter):
        name = 'timeout'
        description = 'Timeout in milliseconds for one request'
        default_value = 10000

    class RetryCountParam(parameters.SandboxIntegerParameter):
        name = 'retry_count'
        description = 'Number of retries for one request on errors'
        default_value = 2

    class MaxErrorsParam(parameters.SandboxStringParameter):
        name = 'max_errors'
        description = 'Maximum allowed errors (absolute or percent% of input size)'
        default_value = '5%'

    class ThreadCountParam(parameters.SandboxIntegerParameter):
        name = 'thread_count'
        description = 'Maximum number of parallel requests'
        default_value = 16

    input_parameters = (
        InputQueries,
        NumQueriesParam,
        UseNoregRequests,
        UseRegionRequests,
        RegionDispatchConfig,
        GeoaC2P,
        LingboostSaasHost,
        SaaSServiceId,
        SaaSIndexPrefix,
        TimeoutParam,
        RetryCountParam,
        MaxErrorsParam,
        ThreadCountParam,
    )

    def on_execute(self):
        """Fetch SaaS responses for the input queries into an LB_SAAS_RESPONSES resource.

        Starts one writer thread and ``thread_count`` worker threads. This
        thread parses input lines and feeds SaasRequest objects to the
        workers. It stops reading early once more than the allowed number of
        requests have failed.

        Raises SandboxTaskFailureError on an invalid thread count or when too
        many requests failed.
        """
        host = utils.get_or_default(self.ctx, self.LingboostSaasHost)
        service_id = utils.get_or_default(self.ctx, self.SaaSServiceId)
        index_prefix = utils.get_or_default(self.ctx, self.SaaSIndexPrefix)
        timeout = utils.get_or_default(self.ctx, self.TimeoutParam)
        retry_count = utils.get_or_default(self.ctx, self.RetryCountParam)
        num_queries = utils.get_or_default(self.ctx, self.NumQueriesParam)
        thread_count = utils.get_or_default(self.ctx, self.ThreadCountParam)
        use_noreg_requests = utils.get_or_default(self.ctx, self.UseNoregRequests)
        use_region_requests = utils.get_or_default(self.ctx, self.UseRegionRequests)
        region_config = json.loads(utils.get_or_default(self.ctx, self.RegionDispatchConfig))

        # With no workers nothing would ever be fetched, and the task would
        # "succeed" with empty output.
        if thread_count < 1:
            raise errors.SandboxTaskFailureError('thread_count must be positive, got {}'.format(thread_count))

        input_queries_filename = self.sync_resource(self.ctx[self.InputQueries.name])
        geoa_c2p = self.load_geoa_c2p()

        # Each input line yields at most one request per enabled request kind,
        # so a percentage error budget scales with the number of enabled kinds.
        # (The previous `a and b` gave 0 when only one kind was enabled, which
        # made any single error fatal.)
        requests_per_query = int(use_noreg_requests) + int(use_region_requests)
        max_errors = self.calc_max_errors(input_queries_filename, num_queries, requests_per_query)
        num_requests = 0
        num_lines = 0

        output_attrs = {}
        if num_queries == 0 and use_noreg_requests and use_region_requests:
            output_attrs['full_responses_for'] = self.ctx[self.InputQueries.name]
        output_resource = self.create_resource(self.descr, 'saasdump.tsv', resource_types.LB_SAAS_RESPONSES, attributes=output_attrs)
        self.ctx['saasdump_id'] = output_resource.id

        tasks_for_workers = Queue.Queue(2 * thread_count)
        completed_tasks = Queue.Queue()
        output_counters = {'lock': threading.Lock(), 'num_requests': 0, 'num_errors': 0}
        with open(input_queries_filename, 'r') as input_file:
            writer_thread = threading.Thread(target=writer_proc, args=(output_counters, output_resource.path, completed_tasks))
            writer_thread.start()
            worker_threads = []
            try:
                # Started inside `try` so that already-running threads are
                # shut down even if starting a later one fails.
                for _ in xrange(thread_count):
                    worker = threading.Thread(target=worker_proc, args=(tasks_for_workers, completed_tasks))
                    worker.start()
                    worker_threads.append(worker)
                for line in input_file:
                    num_lines += 1
                    text, region, relev_locale = self.parse_query_line(line, geoa_c2p)
                    storage_keys_list = self._storage_keys_for_query(
                        relev_locale, use_noreg_requests, use_region_requests, region_config)
                    for storage_keys in storage_keys_list:
                        request = SaasRequest(
                            host_and_port=host,
                            service_id=service_id,
                            index_prefix=index_prefix,
                            timeout=timeout,
                            retry_count=retry_count,
                            text=text,
                            storage_keys=storage_keys,
                            region=region,
                            reqid=num_requests)
                        num_requests += 1
                        tasks_for_workers.put(request)
                    with output_counters['lock']:
                        if output_counters['num_errors'] > max_errors:
                            break
                    # num_queries limits input lines (queries), not issued
                    # requests -- calc_max_errors counts the same lines.
                    if num_queries and num_lines >= num_queries:
                        break
            finally:
                # One sentinel per started worker, then drain the writer.
                for _ in worker_threads:
                    tasks_for_workers.put(None)
                for worker in worker_threads:
                    worker.join()
                completed_tasks.put(None)
                writer_thread.join()

        logging.info('{} requests processed, {} errors'.format(output_counters['num_requests'], output_counters['num_errors']))
        if output_counters['num_errors'] > max_errors:
            raise errors.SandboxTaskFailureError('Too many errors. Aborted')

    @staticmethod
    def _storage_keys_for_query(relev_locale, use_noreg_requests, use_region_requests, region_config):
        """Return the storage-key lists to request for one query.

        The old-style request always uses ["0"]. The region-aware request uses
        the RegionDispatchConfig entry for relev_locale; an unknown locale
        falls back to ["0"], and an empty entry disables the request. Two
        identical lists are collapsed so the same key is not fetched twice.
        """
        storage_keys_list = []
        if use_noreg_requests:
            storage_keys_list.append(["0"])
        if use_region_requests:
            storage_keys = region_config.get(relev_locale, ["0"])
            if storage_keys:
                storage_keys_list.append(storage_keys)
        if len(storage_keys_list) == 2 and storage_keys_list[0] == storage_keys_list[1]:
            storage_keys_list.pop()
        return storage_keys_list

    def calc_max_errors(self, input_queries_filename, num_queries, multiplier):
        """Return the maximum number of failed requests tolerated.

        MaxErrorsParam is either an absolute number or a percentage ending in
        '%'. A percentage is taken of the number of input lines that will be
        processed (capped at num_queries when it is non-zero), multiplied by
        ``multiplier``.
        """
        max_errors = utils.get_or_default(self.ctx, self.MaxErrorsParam).strip()
        if not max_errors.endswith('%'):
            return float(max_errors)
        total_lines = 0
        with open(input_queries_filename, 'r') as input_file:
            for _ in input_file:
                total_lines += 1
                if num_queries and total_lines >= num_queries:
                    break
        return (float(max_errors[:-1]) * total_lines * multiplier) / 100

    def load_geoa_c2p(self):
        """Load the optional geoa.c2p resource as a {child_region: parent_region} dict.

        Returns an empty dict when the GeoaC2P parameter is not set. Each
        non-empty line must be "child<TAB>parent" with integer ids.
        """
        geoa_c2p_rsrc = self.ctx.get(self.GeoaC2P.name)
        if not geoa_c2p_rsrc:
            return {}
        geoa_c2p_filename = self.sync_resource(geoa_c2p_rsrc)
        result = {}
        with open(geoa_c2p_filename, 'r') as geoa_file:
            for line in geoa_file:
                line = line.rstrip('\r\n')
                if not line:
                    continue  # tolerate blank / trailing lines
                child, parent = line.split('\t')
                result[int(child)] = int(parent)
        return result

    def parse_query_line(self, line, geoa_c2p):
        """Parse one input line into (text, region, relev_locale).

        A line containing a tab is treated as USERS_QUERIES ("text<TAB>region").
        Its region is lifted through geoa_c2p until a known country is
        reached, to determine relev_locale.

        Any other line is treated as a request URL. Text, region and locale
        then come from the `relev` parameters (norm, relevgeo, relev_locale),
        falling back to `lr` for the region.
        """
        line = line.rstrip('\r\n')
        if '\t' in line:
            # assume USERS_QUERIES
            text, region = line.split('\t')
            region = int(region)
            lifted_region = region
            visited = set()
            # `visited` guards against cycles in geoa.c2p, which would
            # otherwise loop forever.
            while (lifted_region not in self._COUNTRY_LOCALES
                   and lifted_region in geoa_c2p
                   and lifted_region not in visited):
                visited.add(lifted_region)
                lifted_region = geoa_c2p[lifted_region]
            relev_locale = self._COUNTRY_LOCALES.get(lifted_region, '')
        else:
            # assume PLAIN_TEXT_QUERIES or similar
            # ';' separates relev sub-parameters; escape it so parse_qs keeps
            # them inside one value.
            query = urlparse.urlparse(line).query.replace(';', '%3B')
            query_parsed = urlparse.parse_qs(query)
            relev = ';'.join(query_parsed.get('relev', [])).split(';')
            relev_dict = {x: y for x, y in [z.split('=', 1) for z in relev if '=' in z]}
            text = relev_dict.get('norm', '')  # requests like [site:yandex.ru] have no qnorm
            region = int(relev_dict.get('relevgeo', query_parsed.get('lr', ['-1'])[0]))
            relev_locale = relev_dict.get('relev_locale', '')
        return (text, region, relev_locale)


# Module-level alias for the task class defined in this file; presumably the
# Sandbox framework discovers the task through this name -- confirm.
__Task__ = LbGetSaasResponses
