import time
import json
import logging
import multiprocessing

import sandbox.sdk2 as sdk2
import sandbox.sandboxsdk.environments as environments
import sandbox.common.config as config

from sandbox.projects.common import file_utils as fu
from sandbox.projects.common import solomon

from sandbox.projects.MarketModelWizard import GetDocidsFromXmlSearchResultTsv, GetDocidsFromXmlSearchInputTsv

from requests_handler import get_request_handler

# Solomon push endpoint. NOTE(review): not referenced anywhere in the visible
# code (metrics go through `solomon.upload_to_solomon`) -- confirm whether it
# is still needed.
SOLOMON_URL = 'http://solomon.yandex.net/push/json'
# Value passed as the `json_dump` argument of the request handler; presumably
# a path selecting document ids in the XML-search JSON response -- TODO confirm.
JSON_PATH = 'yandexsearch.response.results.grouping.*.group.*.doc.*.id'


def grep_data(params):
    """Fetch docids for every query in one worker's share of the input.

    The function takes a single dict so that it can be used with
    ``multiprocessing.Pool.map_async``.

    Expected keys of ``params``:
        data: iterable of ``(id, geoid, text)`` triples to query.
        origin: XML search URL handed to the request handler.
        rps: request rate handed to ``get_request_handler``.
        reqinfo: value forwarded as the ``reqinfo`` handler argument.
        numdoc: number substituted into the ``groups-on-page`` part of the
            ``groupby`` handler argument.
        shared_rps: object with a writable ``value`` attribute; receives the
            rate measured over the most recent batch of requests.
        batch_size: number of queries between two ``shared_rps`` updates.

    Returns:
        A tuple ``(results, mean_rps)`` where ``results`` holds one
        ``handler.get_docids_retry`` result per input row (in input order)
        and ``mean_rps`` is ``handler.get_count`` divided by the total
        elapsed time (``0.0`` if no measurable time has elapsed).
    """
    data = params['data']
    origin = params['origin']
    rps = params['rps']
    reqinfo = params['reqinfo']
    numdoc = params['numdoc']
    shared_rps = params['shared_rps']
    batch_size = params['batch_size']

    # NOTE(review): the meaning of the second positional argument (1) is not
    # visible from this file -- confirm against requests_handler.
    handler = get_request_handler(rps, 1)(
        origin,
        json_dump=JSON_PATH,
        reqinfo=reqinfo,
        groupby='attr=d.mode=deep.groups-on-page={numdoc}'.format(numdoc=numdoc),
    )

    results = []
    start_time = time.time()
    previous_batch_time = start_time
    previous_batch_count = 0
    batch = 0
    for _id, geoid, text in data:
        results.append(
            handler.get_docids_retry(
                text, geoid,
                retry=3, sleep=1, sleep_retry=2,
            )
        )
        batch += 1
        if batch < batch_size:
            continue

        batch = 0
        current_time = time.time()
        current_count = handler.get_count
        elapsed = current_time - previous_batch_time
        # The clock may not have advanced between two readings; in that case
        # keep the measurement window open so the next batch covers it instead
        # of dividing by zero.
        if elapsed > 0:
            shared_rps.value = float(current_count - previous_batch_count) / elapsed
            previous_batch_time = current_time
            previous_batch_count = current_count

    # This worker sends no more requests: clear the gauge so that a reader
    # summing the gauges of all workers does not keep counting a stale rate.
    shared_rps.value = 0.0

    total_elapsed = time.time() - start_time
    # An empty `data` can finish within the clock resolution.
    mean_rps = handler.get_count / total_elapsed if total_elapsed > 0 else 0.0
    return (results, mean_rps)


class GetDocidsFromXmlSearch(sdk2.Task):
    """Grep top docid for text in specific geoid

    Reads a TSV resource with ``(id, geoid, text)`` rows, distributes the
    selected rows round-robin over worker processes running ``grep_data``,
    optionally pushes the observed request rate to Solomon while the workers
    run, and publishes the rows with the found docids appended as a
    ``GetDocidsFromXmlSearchResultTsv`` resource.
    """

    class Parameters(sdk2.Task.Parameters):
        # Input table; every row is split on tabs into (id, geoid, text).
        input_tsv = sdk2.parameters.Resource(
            'Input tsv table (id, geoid, text)',
            resource_type=GetDocidsFromXmlSearchInputTsv,
        )

        # URL of the XML search endpoint the workers query.
        with sdk2.parameters.RadioGroup('XML origin') as xml_origin:
            xml_origin.values['https://hamster.yandex.ru/search/xml'] = xml_origin.Value(value='Hamster', default=True)
            xml_origin.values['https://yandex.ru/search/xml'] = xml_origin.Value(value='Yandex')

        # Substituted into the `groups-on-page` part of the `groupby` argument.
        numdoc = sdk2.parameters.Integer('Top documents count', default=10)

        reqinfo = sdk2.parameters.String(
            'Your id from https://wiki.yandex-team.ru/jandeksxml/internal-users/',
            required=True,
        )

        # Number of worker processes in the pool.
        thread_count = sdk2.parameters.Integer('Threads count', default=1)

        rps_per_thread = sdk2.parameters.Integer('RPS per thread', required=True, default=2)

        # Rows with an index below this value are skipped (None means 0).
        url_shift = sdk2.parameters.Integer('First row index', default=None)

        # Maximum number of rows to process after the shift (None means all).
        url_limit = sdk2.parameters.Integer('Rows count', default=None)

        send_solomon_metrics = sdk2.parameters.Bool('Send rps to solomon', default=False)

        with send_solomon_metrics.value[True]:
            with sdk2.parameters.Group('Solomon push-api parameters') as saas_parameters:
                solomon_project = sdk2.parameters.String(
                    'Project',
                    required=True,
                )

                solomon_cluster = sdk2.parameters.String(
                    'Cluster',
                    required=True,
                )

                solomon_service = sdk2.parameters.String(
                    'Service',
                    required=True,
                )

                solomon_sensor = sdk2.parameters.String(
                    'Sensor',
                    required=True,
                )

                solomon_hostname = sdk2.parameters.String(
                    'Hostname, will be sandbox_{{hostname}}',
                    required=True,
                )

    class Requirements(sdk2.Task.Requirements):
        environments = (
            environments.PipEnvironment('solomon'),
        )

    def send_solomon(self, rps):
        """Push the current request rate to Solomon.

        Does nothing unless the ``send_solomon_metrics`` parameter is set.
        The value is rounded to the nearest integer and stamped with the
        current time; ``self.Context.host_hash`` is used as the ``host`` label.
        """
        if not self.Parameters.send_solomon_metrics:
            return

        common_labels = {
            'host': self.Context.host_hash,
            'project': self.Parameters.solomon_project,
            'cluster': self.Parameters.solomon_cluster,
            'service': self.Parameters.solomon_service,
        }
        sensors = [
            {
                'labels': {'sensor': self.Parameters.solomon_sensor, },
                'ts': int(time.time()),
                'value': int(round(rps)),
            }
        ]

        solomon.upload_to_solomon(common_labels, sensors)

    def _read_queries(self, path, thread_count, url_shift, url_limit):
        """Split the selected input rows round-robin into per-worker lists.

        Rows with index below ``url_shift`` are skipped; reading stops after
        ``url_limit`` rows when it is not None. A row goes to the worker whose
        number equals the row's absolute index modulo ``thread_count``.

        Returns ``(chunks, count)``: one list of tab-split rows per worker and
        the total number of selected rows.
        """
        chunks = [[] for _ in range(thread_count)]
        count = 0
        for i, line in enumerate(fu.read_line_by_line(path)):
            if i < url_shift:
                continue
            if url_limit is not None and i >= url_shift + url_limit:
                break
            chunks[i % thread_count].append(line.split('\t'))
            count += 1
        return chunks, count

    def _run_workers(self, pool, params, shared_values):
        """Run ``grep_data`` over ``params`` in ``pool`` and report RPS meanwhile.

        While the map is running, the sum of ``shared_values`` is logged and
        sent to Solomon about once a minute; a zero is sent before the start
        and after the finish. Returns the list of ``grep_data`` results.
        """
        self.send_solomon(0)
        async_map = pool.map_async(grep_data, params)

        async_map.wait(10)

        while True:
            async_map.wait(60)
            if async_map.ready():
                self.send_solomon(0)
                break
            rps = sum(a.value for a in shared_values)
            self.send_solomon(rps)
            logging.info('RPS: {}'.format(rps))

        # Re-raises any exception that happened inside a worker.
        return async_map.get()

    def _format_results(self, chunks, results):
        """Join input rows with their docids into output TSV lines.

        Output is grouped by worker, i.e. it follows the round-robin
        distribution rather than the original row order. A row whose docids
        are None gets an empty last column.

        Returns ``(lines, none_count)`` where ``none_count`` is the number of
        rows without docids.
        """
        lines = []
        none_count = 0
        for chunk, chunk_result in zip(chunks, results):
            for (_id, geoid, text), docids in zip(chunk, chunk_result):
                if docids is None:
                    none_count += 1
                lines.append(
                    '{id}\t{geoid}\t{text}\t{docids}'.format(
                        id=_id,
                        geoid=geoid,
                        text=text,
                        docids=('\t'.join(docids) if docids is not None else ''),
                    )
                )
        return lines, none_count

    def on_execute(self):
        """Task entry point: query all selected rows and publish the result.

        Raises:
            ValueError: if ``thread_count`` or ``rps_per_thread`` is not
                positive.
        """
        if self.Parameters.solomon_hostname:
            self.Context.host_hash = 'sandbox_{}'.format(self.Parameters.solomon_hostname)
        else:
            self.Context.host_hash = 'sandbox_{}'.format(config.Registry().this.id)

        thread_count = int(self.Parameters.thread_count)
        rps_per_thread = int(self.Parameters.rps_per_thread)
        # Fail early with a clear message: non-positive values would otherwise
        # surface later as a division/modulo by zero or a pool creation error.
        if thread_count < 1:
            raise ValueError('Threads count must be positive, got {}'.format(thread_count))
        if rps_per_thread < 1:
            raise ValueError('RPS per thread must be positive, got {}'.format(rps_per_thread))

        # Only logged for reference; the workers build the same arguments.
        cgi_params = dict(
            json_dump=JSON_PATH,
            reqinfo=self.Parameters.reqinfo,
            groupby='attr=d.mode=deep.groups-on-page={}'.format(self.Parameters.numdoc),
        )

        logging.info('cgi parameters: {}'.format(json.dumps(cgi_params, indent=4)))

        input_tsv_data = sdk2.ResourceData(self.Parameters.input_tsv)

        url_shift = int(self.Parameters.url_shift) if self.Parameters.url_shift is not None else 0
        url_limit = int(self.Parameters.url_limit) if self.Parameters.url_limit is not None else None

        chunks, count = self._read_queries(
            str(input_tsv_data.path), thread_count, url_shift, url_limit,
        )

        manager = multiprocessing.Manager()
        pool = None
        try:
            # One gauge per worker; each worker writes its own recent RPS there.
            shared_values = [manager.Value('f', 0.0) for _ in range(thread_count)]
            params = [
                dict(
                    data=chunk,
                    origin=str(self.Parameters.xml_origin),
                    rps=rps_per_thread,
                    reqinfo=str(self.Parameters.reqinfo),
                    numdoc=int(self.Parameters.numdoc),
                    batch_size=120,
                    shared_rps=shared_rps,
                ) for chunk, shared_rps in zip(chunks, shared_values)]

            pool = multiprocessing.Pool(processes=thread_count)

            logging.info('start grep {count} queries from {thread_count} threads'.format(
                count=count,
                thread_count=thread_count)
            )

            logging.info('grep from {origin} with rps {rps}. estimated time {time}s'.format(
                origin=self.Parameters.xml_origin,
                rps=thread_count * rps_per_thread,
                time=float(count) / (thread_count * rps_per_thread))
            )

            tuple_result = self._run_workers(pool, params, shared_values)
        finally:
            # Release the worker and manager processes whether or not the map
            # succeeded; after a successful get() there is no outstanding work,
            # so terminate() does not lose anything.
            if pool is not None:
                pool.terminate()
                pool.join()
            manager.shutdown()

        result = [x[0] for x in tuple_result]
        mean_rps = [x[1] for x in tuple_result]

        logging.info('finish grep. mean RPS: {}. sum mean RPS: {}.'.format(mean_rps, sum(mean_rps)))

        result_lines, none_count = self._format_results(chunks, result)

        logging.info('none count: {}'.format(none_count))

        result_resource_data = sdk2.ResourceData(
            GetDocidsFromXmlSearchResultTsv(
                self,
                'result tsv from task {task_id} resource {resource_id}'.format(
                    task_id=self.id, resource_id=self.Parameters.input_tsv.id,
                ),
                'result.tsv',
                input_resource_id=self.Parameters.input_tsv.id,
                line_shift=url_shift,
                line_count=count,
            )
        )
        fu.write_lines(result_resource_data.path, result_lines)

        result_resource_data.ready()
