# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from itertools import groupby
from sqlalchemy import func

from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.matching import PointMatching, PointType

from travel.rasp.bus.scripts.automatcher.scenarios.base import BaseMatcher
from travel.rasp.bus.scripts.automatcher.policy import TypePolicy


log = logging.getLogger(__name__)


# TODO: handle stations without an OKATO code, e.g. by matching them within a region group
class EtrafficTwinStations(BaseMatcher):
    """Match etraffic stations through their "twins".

    Etraffic station points are twins when they share the same ``title`` and
    the same usable ``extra_info`` (the value checked against
    ``okato_blacklist``, presumably an OKATO code). If the already-matched
    points of a twin group all point to exactly one ``point_key``, the other
    points of that group are matched to that key. Groups with no matched point,
    or with two or more distinct ``point_key`` values, cannot be used and are
    written to the report instead.
    """
    name = 'etraffic_twin_stations'
    point_type_policy = TypePolicy.TYPE_STATION
    supplier = 'etraffic'
    # extra_info values that must never be used to form a twin group
    okato_blacklist = ['-', '0']

    def __init__(self, **params):
        super(EtrafficTwinStations, self).__init__(**params)

        log.info('prepare scenario: %s', self.name)
        rows = self._load_twin_rows(self.get_scenario_supplier_id())

        # Instance attributes (not class attributes) so that separate scenario
        # instances never share or accumulate each other's groups.
        # twin_stations: (extra_info, title) -> the single point_key of the group
        # bad_twin_stations: (extra_info, title) -> list of rows of unusable groups
        self.twin_stations, self.bad_twin_stations = self._classify_twin_groups(rows)

        log.info('before. prepared %d groups with 1 unique key, and %d groups with 0 or 2+ keys',
                 len(self.twin_stations), len(self.bad_twin_stations))

        self.report.append(('twins key', 'title', 'extra_info', 'point_key', 'supplier_point_id'))
        for key, bad_group in self.bad_twin_stations.items():
            joined_key = ', '.join(key)
            for bad_row in bad_group:
                self.report.append((joined_key, ) + tuple(bad_row))

    def _load_twin_rows(self, supplier_id):
        """Load all twin station rows of the given supplier.

        Returns a list of ``(title, extra_info, point_key, supplier_point_id)``
        rows for station points of ``supplier_id`` whose ``extra_info`` is not
        NULL and not blacklisted, and whose (title, extra_info, type)
        combination occurs more than once. The rows are ordered by extra_info
        and title.
        """
        with session_scope() as session:
            # (title, extra_info, type) combinations that occur more than once
            twins_groups = session.query(PointMatching.title, PointMatching.extra_info, PointMatching.type). \
                filter(PointMatching.supplier_id == supplier_id,
                       PointMatching.type == PointType.STATION,
                       PointMatching.extra_info.notin_(self.okato_blacklist),
                       PointMatching.extra_info.isnot(None)). \
                group_by(PointMatching.title, PointMatching.extra_info, PointMatching.type). \
                having(func.count() > 1). \
                subquery()
            return session.query(PointMatching.title,
                                 PointMatching.extra_info,
                                 PointMatching.point_key,
                                 PointMatching.supplier_point_id). \
                filter(PointMatching.extra_info.notin_(self.okato_blacklist),
                       PointMatching.extra_info.isnot(None),
                       PointMatching.supplier_id == supplier_id,
                       PointMatching.extra_info == twins_groups.c.extra_info,
                       PointMatching.title == twins_groups.c.title,
                       PointMatching.type == twins_groups.c.type). \
                order_by(PointMatching.extra_info, PointMatching.title). \
                all()

    @classmethod
    def _classify_twin_groups(cls, rows):
        """Split twin rows into usable and unusable groups.

        A group is usable when its points are matched to exactly one distinct
        station (``point_key``). Such a group can then be used to match the
        remaining points in it. Groups matched to two or more different
        stations, or not matched at all, are unusable but are kept so they can
        be reported.

        :param rows: iterable of rows with ``extra_info``, ``title`` and
            ``point_key`` attributes.
        :return: ``(good, bad)``; ``good`` maps a twin key to its single
            point_key, ``bad`` maps a twin key to the list of the group's rows.
        """
        good, bad = {}, {}
        # groupby only merges *adjacent* equal keys. Sorting here guarantees
        # that, regardless of how the database collation ordered the rows.
        # sorted() is stable, so the SQL order is kept inside each group.
        for key, group in groupby(sorted(rows, key=cls._twin_key), cls._twin_key):
            twins = list(group)
            point_keys = {twin.point_key for twin in twins} - {None}
            if len(point_keys) == 1:
                good[key] = point_keys.pop()
            else:
                bad[key] = twins
        return good, bad

    def _run(self, point):
        """Match ``point`` through its twin group.

        Returns ``(True, point_key)`` when the point belongs to a usable twin
        group, otherwise ``(False, None)``.
        """
        twins_key = self._twin_key(point)
        # Plain dict membership: on Python 2, ``.keys()`` builds a list and
        # makes this lookup O(n) for every point.
        if twins_key in self.twin_stations:
            return True, self.twin_stations[twins_key]
        return False, None

    @staticmethod
    def _twin_key(point):
        """Return the grouping key ``(extra_info, title)`` of a point or row."""
        return point.extra_info, point.title
