# coding=utf-8
import re
from collections import defaultdict
from logging import getLogger
from typing import List, Dict, Tuple

from django.conf import settings

from travel.avia.admin.lib.yt_helpers import yt_client_fabric

if settings.configured:
    from travel.avia.library.python.avia_data.models import NationalVersion
    from travel.avia.library.python.common.models.partner import Partner


class RedirectProvider(object):
    """Aggregate partner redirect (click) counts per route from YT tables.

    The source tables are sorted and reduced on YT by
    ``(fromId, toId, national_version)``; the reduced rows are then read back
    and grouped by national version.

    NOTE: ``NationalVersion`` and ``Partner`` are imported at module level
    only when Django settings are configured, so ``collect`` needs a
    configured Django environment.
    """

    # Columns that identify one reduced row; used for both sort and reduce.
    _ROUTE_KEY = ('fromId', 'toId', 'national_version')

    def __init__(self, yt_fabric, logger):
        """
        :param yt_fabric: object whose ``create()`` returns a YT client
        :param logger: logger used for progress messages
        """
        self._yt_fabric = yt_fabric
        self._logger = logger

    def collect(self, sources):
        # type: (List[str]) -> Dict[str, Dict[Tuple[str, str], Dict[str, int]]]
        """
        Count redirects per partner for every route found in ``sources``.

        :param sources: source tables passed to the YT sort operation
        :return: national_version -> (route -> (partner -> click)), where
            route is a ``(fromId, toId)`` tuple. Routes without any counted
            click are omitted, and only national versions whose code exists
            in ``NationalVersion`` are returned.
        """
        routes_by_nv = defaultdict(dict)
        with self._yt_fabric.create().TempTable() as temp_table:
            self._collect(sources, temp_table)

            for row in self._yt_fabric.create().read_table(temp_table):
                # The per-version dict is looked up before the emptiness
                # check on purpose: a national version stays in the result
                # even when all of its routes have no clicks.
                routes = routes_by_nv[row['national_version']]
                if row['clicks']:
                    routes[(row['fromId'], row['toId'])] = row['clicks']

        for nv, routes in routes_by_nv.items():
            # The national version is part of the message so that the
            # per-version lines can be told apart in the log.
            self._logger.info(
                'count record [%d] by national version [%s]',
                len(routes), nv
            )

        all_nv = set(NationalVersion.objects.values_list('code', flat=True))
        return {
            nv: routes
            for nv, routes in routes_by_nv.items()
            if nv in all_nv
        }

    @staticmethod
    def _build_reducer(billing_order_id_to_partner_id):
        """
        Build the reducer that counts redirects per partner for one route.

        :param billing_order_id_to_partner_id: mapping used to translate the
            ``billing_order_id`` of an input row into a partner id
        :return: generator function ``reducer(key, items)``. It yields
            nothing when ``key['fromId']`` or ``key['toId']`` is not a point
            key, otherwise it yields a single copy of ``key`` extended with
            ``clicks`` (partner id -> number of rows).
        """
        # ``basestring`` exists only on Python 2; use ``str`` elsewhere.
        try:
            string_types = basestring  # noqa: F821
        except NameError:
            string_types = str

        is_settlement = re.compile(r'^c\d+$')
        is_station = re.compile(r'^s\d+$')

        def _is_point(point_key):
            """Tell whether point_key looks like ``c<digits>``/``s<digits>``."""
            if not point_key or not isinstance(point_key, string_types):
                return False
            return bool(
                is_station.match(point_key) or is_settlement.match(point_key)
            )

        def _calculate_redirect_frequency_by_route_and_partner(key, items):
            if not (_is_point(key['fromId']) and _is_point(key['toId'])):
                return

            clicks = defaultdict(int)
            for item in items:
                try:
                    partner_id = billing_order_id_to_partner_id[
                        item['billing_order_id']
                    ]
                except (KeyError, TypeError):
                    # Best effort: a row without billing_order_id, with an
                    # unknown one, or with an unusable value is not counted.
                    continue
                clicks[partner_id] += 1

            answer = dict(key)
            answer['clicks'] = dict(clicks)
            yield answer

        return _calculate_redirect_frequency_by_route_and_partner

    def _collect(self, sources, output):
        """
        Sort ``sources`` into ``output`` and reduce ``output`` in place.

        :param sources: source tables passed to the YT sort operation
        :param output: table that receives the sorted rows and is then
            overwritten with the reduced rows
        """
        billing_order_id_to_partner_id = {
            str(p['billing_order_id']): str(p['id'])
            for p in Partner.objects.values('billing_order_id', 'id')
        }
        reducer = self._build_reducer(billing_order_id_to_partner_id)

        self._logger.info('start sort source tables')
        self._yt_fabric.create().run_sort(
            sources,
            output,
            sort_by=list(self._ROUTE_KEY)
        )
        self._logger.info('finish sort source tables')

        self._logger.info('start reduce')
        self._yt_fabric.create().run_reduce(
            reducer,
            output,
            output,
            reduce_by=list(self._ROUTE_KEY),
        )
        self._logger.info('finish reduce')


# Module-level instance wired to the YT client fabric and to this module's
# logger.
redirect_provider = RedirectProvider(
    logger=getLogger(__name__),
    yt_fabric=yt_client_fabric,
)
