# coding=UTF-8

import logging
import sys
from collections import defaultdict
from datetime import datetime

from sandbox.sandboxsdk.task import SandboxTask
from sandbox.sandboxsdk.svn import Arcadia
from sandbox.sandboxsdk import parameters, environments

# Batch size that WebQueriesFilter.on_execute hands to download_serp
# (forwarded there to the scraper task builder via set_batch_size).
BATCH_SIZE = 1500


class MrInputTableParam(parameters.ListRepeater, parameters.SandboxStringParameter):
    # Input table(s) to read from; the task iterates over this value and
    # collects the 'query' field of every row.
    group = 'MR'
    required = True
    name = 'mr_input_table'
    description = 'Input MapReduce table(s) with queries'


class MrOutputTablePrefixParam(parameters.SandboxStringParameter):
    # Optional prefix glued directly in front of the output table name
    # (an empty/missing value is treated as no prefix by the task).
    group = 'MR'
    required = False
    name = 'mr_output_prefix'
    description = 'Output MapReduce table name prefix'


class MrOutputTablePathParam(parameters.SandboxStringParameter):
    # Path under which the output table is created; the task joins it with
    # the table prefix/name using '/', so trailing slashes are removed here.
    name = 'mr_output_path'
    description = 'Output MapReduce path'
    required = True
    group = 'MR'

    @classmethod
    def cast(cls, value):
        """Cast the value with the base class and drop trailing slashes.

        :param value: raw parameter value
        :return: the base-class cast result without trailing '/' characters;
                 None is passed through unchanged instead of raising
                 AttributeError on ``.rstrip``.
        """
        value = super(MrOutputTablePathParam, cls).cast(value)
        if value is None:
            # Nothing to normalise; let the framework's own "required"
            # validation deal with the missing value.
            return value
        return value.rstrip('/')


class MrOutputTableNameParam(parameters.SandboxStringParameter):
    # Output table name; when left empty/blank the task's on_enqueue hook
    # replaces it with today's date formatted as YYYY-MM-DD.
    group = 'MR'
    required = False
    name = 'mr_output_table_name'
    description = 'Output MapReduce table name. If empty will use current date'


class MrServerParam(parameters.SandboxStringParameter):
    # Server address passed as the first argument to YtClient in on_execute.
    group = 'MR'
    required = True
    name = 'mr_server'
    description = 'MapReduce server'


class RegionParam(parameters.SandboxIntegerParameter):
    # Region id sent to the scraper as the 'region-id' parameter for every query.
    name = 'region'
    description = 'Region for all queries'
    required = True
    # Integer default to match the parameter type (was the string '225'),
    # consistent with ScraperPageSizeParam's integer default.
    default_value = 225
    group = 'Main'


class HostParam(parameters.SandboxStringParameter):
    # Host the scraper is pointed at (forwarded to the builder's set_host).
    group = 'Main'
    required = True
    name = 'host'
    description = 'Host to take serps from'
    default_value = 'http://hamster.yandex.ru'


class AuthParam(parameters.SandboxStringParameter):
    # Vault owner whose 'yt_token', 'scraper_user' and 'scraper_oauth'
    # entries are fetched with get_vault_data in on_execute.
    group = 'Auth'
    required = True
    name = 'auth_vault'
    description = "Sandbox vault user (must contain 'scraper_user', 'scraper_oauth', 'yt_token' values)"


class ScraperPageSizeParam(parameters.SandboxIntegerParameter):
    # Sent to the scraper as its 'results-per-page' parameter.
    group = 'Scraper'
    required = True
    default_value = 10
    name = 'scraper_page_size_param'
    description = 'Scraper page size param'


class ScraperCgiParam(parameters.ListRepeater, parameters.SandboxStringParameter):
    # Extra CGI parameters as 'name=value' strings; the task splits each one
    # on the first '=' and forwards them to the scraper as 'additional-cgi'.
    group = 'Scraper'
    required = False
    name = 'scraper_cgi_params'
    description = 'Scraper cgi parameters'


# Short preset key -> scraper preset name. The values are what download_serp
# passes to the task builder's set_preset (its default is PRESETS['web']);
# the pairs are also used to build the ScraperPresetParam choices.
PRESETS = {
    'web': 'yandex-web-profiled',
    'video': 'yandex-video-profiled',
}


class ScraperPresetParam(parameters.SandboxRadioParameter):
    # Radio selector for the scraper preset, built from the PRESETS mapping.
    # NOTE(review): default_value is the short key 'web', while download_serp's
    # own default is the full preset name PRESETS['web'] -- confirm which form
    # Sandbox stores in ctx for a radio parameter.
    group = 'Scraper'
    required = False
    name = 'scraper_preset_param'
    description = 'Scraper preset'
    choices = [('web', PRESETS['web']), ('video', PRESETS['video'])]
    default_value = 'web'


# Priority set (via set_priority) on every scraper task built by download_serp.
DEFAULT_SCRAPER_PRIORITY = 50


def download_serp(user, oauth_token, task_name, requests, cgi_params, region, host, page_size,
                  batch_size=500, preset=PRESETS['web']):
    """Run the given queries through the scraper and collect result URLs.

    :param user: scraper user, stored in the task meta together with task_name
    :param oauth_token: token handed to the scraper TaskBuilder
    :param task_name: task name stored in the task meta
    :param requests: iterable of query objects, each passed to builder.add_query
    :param cgi_params: iterable of (name, value) pairs; values are grouped by
                       name and sent as the 'additional-cgi' parameter
    :param region: value of the 'region-id' parameter
    :param host: host passed to builder.set_host
    :param page_size: value of the 'results-per-page' parameter
    :param batch_size: batch size passed to builder.set_batch_size
    :param preset: scraper preset passed to builder.set_preset
    :return: dict mapping serp['query'] to a list of (position, url) tuples,
             where position is the index of the url inside its serp
    :raises ScrapError: re-raised (after logging) when a batch result cannot
                        be fetched
    """
    from scraper.v2.task_builder import TaskBuilder
    from scraper.v2.errors import ScrapError

    # Each setter's return value is re-bound, exactly as a chained call would do.
    builder = TaskBuilder(oauth_token)
    builder = builder.set_preset(preset)
    builder = builder.set_meta(user, task_name)
    builder = builder.set_batch_size(batch_size)
    builder = builder.set_parameter('results-per-page', page_size)
    builder = builder.set_host(host)
    builder = builder.set_priority(DEFAULT_SCRAPER_PRIORITY)

    for query in requests:
        builder.add_query(query)

    # Collect all values given for the same CGI parameter name.
    cgi_by_name = defaultdict(list)
    for pair in cgi_params:
        cgi_name, cgi_value = pair
        cgi_by_name[cgi_name].append(cgi_value)

    scraper_parameters = {
        'region-id': region,
        'additional-cgi': cgi_by_name,
    }
    scraper_tasks = builder.set_parameters(scraper_parameters).create_tasks()

    urls_by_query = defaultdict(list)
    for scraper_task in scraper_tasks:
        # Tasks are started and awaited one after another.
        scraper_task.run().wait()
        try:
            serps = scraper_task.get()
        except ScrapError:
            logging.error('Failed to load scraper batch')
            raise

        for serp in serps:
            urls_by_query[serp['query']].extend(enumerate(serp['urls']))

    return dict(urls_by_query)


def prepare_requests(request_file):
    """Build scraper request dicts from a tab-separated query file.

    Each line is stripped of surrounding whitespace and its first
    tab-separated column is taken as the query text. Blank or
    whitespace-only lines carry no query; they are logged and skipped.

    :param request_file: path to the text file with one query per line
    :return: list of ``{'query-text': <query>}`` dicts in file order
    """
    requests = []
    with open(request_file, 'r') as f:
        for line in f:
            # str.split always returns at least one element (possibly ''),
            # so the query itself has to be tested, not the list.
            query = line.strip().split('\t')[0]
            if not query:
                logging.warning('Encountered a malformed line: %r', line)
                continue
            requests.append({'query-text': query})
    return requests


class WebQueriesFilter(SandboxTask):
    """
        Filtering web queries, SEARCHSPAM-10664

        Reads queries from the input tables, downloads SERPs for them through
        the scraper and writes the distinct result URLs to a sorted output table.
    """

    type = 'WEB_QUERIES_FILTER'

    input_parameters = [
        MrServerParam,
        MrInputTableParam,
        MrOutputTablePathParam,
        MrOutputTablePrefixParam,
        MrOutputTableNameParam,
        RegionParam,
        HostParam,
        AuthParam,
        ScraperPageSizeParam,
        ScraperCgiParam,
        ScraperPresetParam
    ]
    environment = (
        environments.PipEnvironment('yandex-yt', '0.8.11-0'),
        environments.PipEnvironment('yandex-yt-yson-bindings-skynet', '0.3.7.post1'),
    )

    local_dir = 'local_files'
    # Task name stored in the scraper task meta.
    scraper_task_name = 'sandbox:web queries filter'

    @staticmethod
    def _cast_cgi_params(cgi_list):
        """Split raw 'name=value' strings into [name, value] pairs.

        :param cgi_list: iterable of strings, or a falsy value for "none"
        :return: list of two-element lists; blank entries are skipped, the
                 split is done on the first '=' only, and an entry without
                 '=' gets an empty value so every item can be unpacked as
                 (name, value) by download_serp
        """
        result = []
        if not cgi_list:
            return result
        for val in cgi_list:
            if not val.strip():
                continue
            pair = val.split(u'=', 1)
            if len(pair) == 1:
                # No '=' in the entry: previously this produced a one-element
                # list that broke the (name, value) unpacking downstream.
                pair.append(u'')
            result.append(pair)
        return result

    def _get_queries_from_mapreduce(self, yt_client):
        """Collect the distinct non-empty 'query' values of all input tables.

        :param yt_client: client object providing read_table
        :return: set of query values
        """
        result = set()

        for table in self.ctx[MrInputTableParam.name]:
            for item in yt_client.read_table(table, format='yson', raw=False):
                query = item.get('query')
                if query:
                    result.add(query)
        return result

    def _send_docs_to_mapreduce(self, yt_client, found):
        """Write the distinct URLs of all SERPs to the output table and sort it.

        :param yt_client: client object providing write_table and run_sort
        :param found: dict mapping a query to a list of (position, url) pairs
        """
        # Only the URLs matter here; queries and positions are discarded.
        urls = {url for docs in found.values() for _pos, url in docs}
        result = [{'url': url} for url in urls]

        table = '{}/{}{}'.format(
            self.ctx[MrOutputTablePathParam.name],
            self.ctx[MrOutputTablePrefixParam.name] or '',
            self.ctx[MrOutputTableNameParam.name]
        )

        yt_client.write_table(table, result)
        yt_client.run_sort(table, sort_by='url')

    def on_enqueue(self):
        """Default the output table name to today's date (YYYY-MM-DD) when it is empty or blank."""
        value = self.ctx[MrOutputTableNameParam.name]
        if not value or not value.strip():
            value = datetime.today().strftime('%Y-%m-%d')
        self.ctx[MrOutputTableNameParam.name] = value
        SandboxTask.on_enqueue(self)

    def on_execute(self):
        """Read queries, scrape their SERPs and store the resulting URLs."""
        # The scraper package is imported by download_serp from this checkout.
        sys.path.append(Arcadia.get_arcadia_src_dir('arcadia:/arc/trunk/arcadia/yweb/antispam/util'))  # scraper

        from yt.wrapper import YtClient

        server = self.ctx[MrServerParam.name]
        token = self.get_vault_data(self.ctx[AuthParam.name], 'yt_token')
        yt_client = YtClient(server, token, {'yamr_mode': {'treat_unexisting_as_empty': True}})

        queries = self._get_queries_from_mapreduce(yt_client)

        scraper_user = self.get_vault_data(self.ctx[AuthParam.name], 'scraper_user')
        oauth_token = self.get_vault_data(self.ctx[AuthParam.name], 'scraper_oauth')

        requests = [{'query-text': query} for query in queries]

        cgi_params = WebQueriesFilter._cast_cgi_params(self.ctx[ScraperCgiParam.name])

        # The context may hold either a short key ('web', the declared default)
        # or an already-resolved preset name; download_serp expects the latter
        # (its own default is PRESETS['web']), so translate short keys and pass
        # everything else through unchanged.
        preset = self.ctx[ScraperPresetParam.name]
        preset = PRESETS.get(preset, preset)

        found = download_serp(
            scraper_user, oauth_token, self.scraper_task_name, requests, cgi_params,
            self.ctx[RegionParam.name],
            self.ctx[HostParam.name],
            self.ctx[ScraperPageSizeParam.name],
            batch_size=BATCH_SIZE,
            preset=preset
        )

        self._send_docs_to_mapreduce(yt_client, found)


# NOTE(review): presumably the module-level hook through which Sandbox
# discovers the task class defined in this file -- confirm against the SDK.
__Task__ = WebQueriesFilter
