# -*- coding: utf-8 -*-

from __future__ import print_function, absolute_import, division

import re
import copy
import urlparse
import collections
import itertools
import logging

from nile.api.v1 import (
    Record,
    aggregators as na
)
from qb2.api.v1 import (
    extractors as qe,
    filters as qf
)


class TweakMapper(object):
    """Mapper that normalizes ``referer`` and ``eventtype`` of every record.

    The following ``config`` keys are read:

    * ``referer_regex`` - pattern a normalized referer has to match
      (from its start), otherwise the referer becomes ``'other'``;
    * ``irrelevant_preficies`` - hostname prefixes stripped from the referer;
    * ``zombie_hosts`` - ``distr_obj`` prefixes used by the browser_zombie rule;
    * ``promo_domains`` - substrings of the referer that enable the promo rule;
    * ``events_tweaks`` - mapping ``eventtype -> replacement eventtype``.
    """

    # Fields copied from the tweaked context into every emitted record.
    _OUTPUT_FIELDS = (
        'key', 'reqid', 'showid', 'client', 'subclient', 'unixtime',
        'distr_obj', 'eventtype', 'region', 'bannerid', 'score', 'device',
        'device_id', 'os', 'browser', 'referer', 'product', 'yandexuid',
        'uuid', 'testids', 'type', 'origin', 'past',
    )

    def __init__(self, config):
        super(TweakMapper, self).__init__()
        self.config = config
        self.referer_regex = re.compile(self.config['referer_regex'])

    def __call__(self, records):
        """Yield one ``Record`` per input record with tweaked fields.

        Each input record must provide ``to_dict()`` and contain every field
        of ``_OUTPUT_FIELDS`` (a missing one raises ``KeyError``).
        """
        for record in records:
            context = record.to_dict()
            # The referer goes first: the promo rule of the eventtype tweak
            # inspects the already normalized referer.
            self.__tweak_referer(context)
            self.__tweak_eventtype(context)
            yield Record(**{
                field: context[field] for field in self._OUTPUT_FIELDS
            })

    def __tweak_referer(self, context):
        """Replace ``context['referer']`` with its normalized form.

        The result is ``hostname[/path]`` when it matches ``referer_regex``
        and ``'other'`` when it does not. If the referer cannot be parsed at
        all (or normalizes to an empty string), ``context['service']`` is
        used instead.
        """
        try:
            parsed = urlparse.urlparse(context['referer'])
            hostname = self.__strip_prefixes(
                parsed.hostname.strip('!@#$%^&*/')
            )
            path = self.__normalize_path(parsed.path.strip('!@#$%^&*/'))
            referer = '/'.join([hostname, path]).strip('/')
            if self.referer_regex.match(referer) is None:
                referer = 'other'
        except Exception:
            # Deliberate best effort: a missing or malformed referer must
            # not fail the whole job, the service name is used instead.
            referer = None
        context['referer'] = referer or context['service']

    def __strip_prefixes(self, hostname):
        """Strip configured prefixes from ``hostname`` until none matches."""
        # Empty prefixes are skipped: ``startswith('')`` is always true, so
        # an empty entry in the config would keep the loop running forever.
        prefixes = [
            prefix for prefix in self.config['irrelevant_preficies'] if prefix
        ]
        stripped = True
        while stripped:
            stripped = False
            for prefix in prefixes:
                if hostname.startswith(prefix):
                    hostname = hostname[len(prefix):]
                    stripped = True
        return hostname

    @staticmethod
    def __normalize_path(path):
        """Reduce a URL path to the part that identifies the service page."""
        path = path.replace('yandsearch', 'search')
        path = path.replace('touchsearch', 'search')
        parts = path.split('/')
        if 'images/touch/search' in path or 'images/pad/search' in path:
            # These image search pages keep their first three segments.
            path = '/'.join(parts[:3])
        else:
            path = parts[0]
        # Only a real ``.html`` suffix is cut off; a ``.html`` in the middle
        # of the segment must not cost the segment its last five characters.
        if path.endswith('.html'):
            path = path[:-len('.html')]
        return path

    def __promolib_tweak_required(self, context):
        """Tell whether this is an install whose distr_obj starts with 'promolib'."""
        return (
            (context['distr_obj'] or '').startswith('promolib') and
            context['eventtype'] == 'install'
        )

    def __browser_zombie_tweak_required(self, context):
        """Tell whether this is a browser_zombie install/showlanding on a zombie host."""
        distr_obj = context['distr_obj'] or ''
        is_zombie_host = any(
            distr_obj.startswith(host) for host in self.config['zombie_hosts']
        )
        return (
            is_zombie_host and
            context['product'] == 'browser_zombie' and
            context['eventtype'] in {'install', 'showlanding'}
        )

    def __set_opera_tweak_required(self, context):
        """Tell whether this is an install of the set_opera product."""
        # Equality, not ``in 'install'``: the substring test also accepted
        # '', 'in', 'stall', ... and raised TypeError for a None eventtype.
        return (
            context['product'] == 'set_opera' and
            context['eventtype'] == 'install'
        )

    def __promo_tweak_required(self, context):
        """Tell whether the referer contains one of the promo domains."""
        referer = context['referer'] or ''
        return any(
            domain in referer for domain in self.config['promo_domains']
        )

    def __tweak_eventtype(self, context):
        """Rename ``context['eventtype']`` via ``events_tweaks`` if any rule fires."""
        tweak_required = (
            self.__promolib_tweak_required(context) or
            self.__browser_zombie_tweak_required(context) or
            self.__set_opera_tweak_required(context) or
            self.__promo_tweak_required(context)
        )
        if tweak_required:
            tweaks = self.config['events_tweaks']
            if context['eventtype'] in tweaks:
                context['eventtype'] = tweaks[context['eventtype']]


class FuseReducer(object):
    """Reducer that links non-show events to the shows that caused them.

    A small vocabulary:

    1. context - an event and all its attributes like yandexuid, unixtime etc.
    2. aftermath - everything that happens after show, e.g. click, close,
       install etc.
    3. motive - the cause of the context, for now only show can be a motive.

    The only ``config`` key read is ``augment_keys``: the attributes an
    aftermath event inherits from its motive when its own value is falsy.
    """

    def __init__(self, config):
        super(FuseReducer, self).__init__()
        self.config = config

    def __call__(self, groups):
        """Yield shows and augmented aftermath events for every group.

        ``groups`` is an iterable of ``(key, records)`` pairs; the per-group
        state below is rebuilt from scratch for each of them.
        """
        for _, records in groups:
            categories = {
                # Non-show events, in arrival order.
                'aftermath': [],
                # Despite the names, these two hold the *shows* (candidate
                # motives) indexed by showid / reqid.
                'showids_aftermath': collections.defaultdict(list),
                'reqids_aftermath': collections.defaultdict(list),
                # Event types already emitted per motive showid.
                'showids_events': collections.defaultdict(set)
            }
            # All shows have to be collected before any aftermath event is
            # resolved, so the first pass must be exhausted first.
            for record in self.__fuse_shows(records, categories):
                yield record
            for record in self.__fuse_aftermath(categories):
                yield record

    def __fuse_shows(self, records, categories):
        """Sort records into ``categories`` and yield the non-past shows.

        Shows with neither ``reqid`` nor ``showid`` are dropped. Shows flagged
        as ``past`` are kept as motives but are not emitted.
        """
        for record in records:
            context = record.to_dict()
            showid = context['showid']
            reqid = context['reqid']
            if context['eventtype'] != 'show':
                categories['aftermath'].append(context)
            elif reqid or showid:  # valid shows handling
                if reqid:
                    categories['reqids_aftermath'][reqid].append(context)
                if showid:
                    categories['showids_aftermath'][showid].append(context)
                if not context['past']:
                    yield Record(**context)

    def __fuse_aftermath(self, categories):
        """Yield every non-past aftermath event that got a motive."""
        for context in categories['aftermath']:
            if context['past']:
                continue
            refined = self.__refine_context(
                context,
                self.__look_for_motives(context, categories),
                categories
            )
            if refined:
                yield Record(**refined)

    def __refine_context(self, context, motives, categories):
        """Augment ``context`` from its latest suitable motive.

        Returns ``None`` when no motive passes the filter or when the same
        event type was already emitted for the chosen motive.
        """
        is_real = self.__get_motives_filter(context)
        real_motives = [motive for motive in motives if is_real(motive)]
        if not real_motives:
            return None
        # NOTE(review): the latest show wins even if its unixtime is greater
        # than the event's own - confirm that this is intended.
        latest = max(real_motives, key=lambda motive: motive['unixtime'])
        return self.__augment_context(context, latest, categories)

    def __get_motives_filter(self, context):
        """Return a predicate telling whether a show may be the motive.

        Matching is done by ``bannerid`` when the event has one, otherwise by
        ``distr_obj``; an event with neither accepts any show.
        """
        if context['bannerid']:
            return lambda motive: motive['bannerid'] == context['bannerid']
        elif context['distr_obj']:
            return lambda motive: motive['distr_obj'] == context['distr_obj']
        else:
            return lambda motive: True

    def __augment_context(self, context, motive, categories):
        """Return a copy of ``context`` with falsy ``augment_keys`` taken from ``motive``.

        Only the first event of each type is kept per motive showid; for
        repeated ones ``None`` is returned. ``context`` itself is not changed.
        """
        # NOTE(review): motives found by reqid may have no showid, so all of
        # them share the ``None`` bucket here - confirm that this is intended.
        seen_events = categories['showids_events'][motive['showid']]
        eventtype = context['eventtype']
        if eventtype in seen_events:
            return None
        seen_events.add(eventtype)
        # The copy is made only now, so rejected duplicates cost nothing.
        augmented = copy.deepcopy(context)
        for key in self.config['augment_keys']:
            if not augmented[key]:
                augmented[key] = motive[key]
        return augmented

    def __look_for_motives(self, context, categories):
        """Return the shows sharing the event's showid, or else its reqid.

        The reqid is consulted only when the event has no showid; an event
        with neither id has no motives.
        """
        if context['showid']:
            id_type = 'showid'
        elif context['reqid']:
            id_type = 'reqid'
        else:
            return []
        shows_by_id = categories['{}s_aftermath'.format(id_type)]
        # ``.get`` instead of ``[]``: indexing a defaultdict would insert an
        # empty list for every id that has no show, growing reducer memory.
        motives = shows_by_id.get(context[id_type], ())
        return [motive for motive in motives if motive is not None]


class SlicesMapper(object):
    """Mapper that expands an aggregated record into all of its slices.

    For every column listed in ``config['columns']`` a record contributes
    both to its own value and to ``'_total_'``; ``product`` and ``os`` can
    contribute to one more synthetic value each. One output record is
    produced per combination of these variants, carrying the event counters
    of the source record.

    ``config`` keys read: ``columns`` and ``allextensions_products``.
    """

    def __init__(self, config):
        super(SlicesMapper, self).__init__()
        self.config = config
        # Built once here rather than for every processed record.
        self.allextensions_products = frozenset(
            self.config['allextensions_products']
        )

    def __call__(self, records):
        """Yield a ``Record`` for every slice of every input record."""
        for record in records:
            context = record.to_dict()
            for slice_ in self.__make_slices_enumeration(context):
                yield Record(**slice_)

    def __make_slices_enumeration(self, context):
        """Yield one dict per slice: column values plus the event counters.

        ``context['events_counts']`` may be anything ``dict()`` accepts
        (a mapping or an iterable of pairs).
        """
        attrs, combinations = self.__make_combinations(context)
        events_counts = dict(context['events_counts'])
        for combination in itertools.product(*combinations):
            slice_ = dict(zip(attrs, combination))
            slice_.update(events_counts)
            yield slice_

    def __make_combinations(self, context):
        """Return ``(columns, variants)`` with one tuple of variants per column.

        Every column offers ``(own value, '_total_')``. On top of that:

        * ``product`` gets ``'browser_<type>'`` for a browser with a truthy
          ``type``, or ``'_allextensions_'`` for the configured products;
        * ``os`` gets ``'<os>_<device>'`` when ``device`` is truthy.

        The extras are applied only to columns that are actually configured.
        """
        attrs = self.config['columns']
        variants = {
            attr: (context[attr], '_total_')
            for attr in attrs
        }
        if 'product' in variants:
            product = context['product']
            if product == 'browser' and context['type']:
                variants['product'] += (
                    '{}_{}'.format(product, context['type']),
                )
            elif product in self.allextensions_products:
                variants['product'] += ('_allextensions_',)
        if 'os' in variants and context['device']:
            variants['os'] += (
                '{}_{}'.format(context['os'], context['device']),
            )
        return attrs, [variants[attr] for attr in attrs]


class AtomDistributionAggregator(object):
    """Builds the ``atom_cube`` and ``totals`` pipelines from composite data.

    ``config`` sections read: ``tweak`` (handed to ``TweakMapper``), ``fuse``
    (handed to ``FuseReducer``) and ``accumulate`` (its ``events_aliases`` and
    ``columns`` are used here, the whole section goes to ``SlicesMapper``).
    """

    def __init__(self, config):
        super(AtomDistributionAggregator, self).__init__()
        self.config = config
        self.logger = logging.getLogger(__name__)

    def aggregate(self, composite):
        """Return a dict with the ``'atom_cube'`` and ``'totals'`` pipelines.

        ``composite`` has to support the ``map``/``sort``/``groupby`` calls
        made below (presumably a nile stream - verify against the caller).
        """
        self.logger.info('Aggregating composite data')
        atom_cube = self.__make_atom_cube(composite)
        return {
            'atom_cube': atom_cube,
            'totals': self.__make_totals(atom_cube)
        }

    def __make_atom_cube(self, composite):
        """Tweak every record, then fuse events with their shows per ``key``."""
        self.logger.info('Building atom_cube')
        tweaked = composite.map(TweakMapper(self.config['tweak']))
        # Within a key the reducer receives records ordered by time.
        grouped = tweaked.sort('key', 'unixtime', 'yandexuid').groupby('key')
        # NOTE(review): the unit of memory_limit is defined by the reduce
        # API, not by this file - confirm before changing the value.
        return grouped.reduce(
            FuseReducer(self.config['fuse']),
            memory_limit=24288
        )

    def __make_totals(self, atom_cube):
        """Count aliased events for every slice of the configured columns."""
        self.logger.info('Building totals')
        accumulate = self.config['accumulate']
        aliases = set(accumulate['events_aliases'].values())

        # Stage 1: one record per testid (a single None when there are no
        # testids), with the raw eventtype replaced by its alias; events
        # without an alias are filtered out.
        unfolded_fields = [
            qe.all(exclude=('testids', 'eventtype')),
            qe.log_field('testids').hide(),
            qe.log_field('eventtype').rename('raw_eventtype').hide(),
            qe.custom('test_ids', lambda testids: testids or [None], 'testids').hide(),
            qe.unfold('testid', sequence='test_ids'),
            qe.custom(
                'eventtype',
                lambda raw_eventtype: self.config['accumulate']['events_aliases'].get(raw_eventtype),
                'raw_eventtype'
            )
        ]
        unfolded = atom_cube.qb2(
            log='generic-yson-log',
            fields=unfolded_fields,
            filters=[qf.defined('eventtype')]
        )

        # Stage 2: histogram of event aliases per full combination of attributes.
        histograms = unfolded.groupby(
            'product', 'referer', 'region', 'browser', 'os',
            'testid', 'distr_obj', 'type', 'device'
        ).aggregate(
            events_counts=na.histogram('eventtype')
        )

        # Stage 3: expand into slices and sum every alias counter per slice.
        sliced = histograms.map(SlicesMapper(accumulate))
        summed = sliced.groupby(
            *accumulate['columns']
        ).aggregate(
            **{
                '{}s'.format(eventtype): na.sum(eventtype)
                for eventtype in aliases
            }
        )

        # Stage 4: counters absent from a slice default to 0.
        counter_fields = [
            qe.log_field('{}s'.format(eventtype), default=0)
            for eventtype in aliases
        ]
        return summed.qb2(
            log='generic-yson-log',
            fields=[qe.all()] + counter_fields
        )
