# coding: utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import requests

from collections import Counter
from sqlalchemy import func, and_, or_

from travel.rasp.bus.db import session_scope
from travel.rasp.bus.db.models.matching import PointMatching, PointType
from travel.rasp.bus.scripts.automatcher.scenarios.base import BaseMatcher
from travel.rasp.bus.scripts.automatcher.policy import TypePolicy

# Module-level logger used by the prepare functions and the matcher class in this file.
log = logging.getLogger(__name__)


def prepare_by_code(stations, config):
    """Build a ``code -> 's<station_id>'`` lookup from rasp station records.

    :param stations: iterable of dicts, each with ``'code'`` and ``'station_id'`` keys.
    :param config: accepted so every preparer shares one signature; not used here.
    :return: dict mapping each station code to its point key. When a code occurs
        more than once, the last record wins.
    """
    log.info('prepare station codes')
    return {
        record['code']: 's{}'.format(record['station_id'])
        for record in stations
    }


def prepare_by_title_etraffic(stations, config):
    """Build a ``title -> 's<station_id>'`` lookup for etraffic station titles.

    Only titles that are unambiguous on both sides are kept:

    * on the rasp side, the title must occur exactly once across all
      ``station['titles']`` lists;
    * on the supplier side (``PointMatching`` rows of type ``STATION`` for the
      etraffic supplier), the title must have a single distinct ``extra_info``
      value and, when that value is NULL, must additionally be unambiguous by
      region (see the final filter below).

    NOTE(review): judging by the variable names, ``extra_info`` holds an OKATO
    code for etraffic points - confirm against the model.

    :param stations: iterable of dicts with ``'titles'`` (iterable of titles)
        and ``'station_id'`` keys. It is iterated twice, so it must be re-iterable.
    :param config: full matcher config; must contain a truthy ``'etraffic'``
        entry with a ``'buses_id'`` key.
    :return: dict mapping each accepted title to ``'s<station_id>'``.
    :raises ValueError: if ``config`` has no (or an empty) ``'etraffic'`` entry.
    """
    log.info('prepare etraffic station titles')
    etraffic_cfg = config.get('etraffic')
    if not etraffic_cfg:
        raise ValueError("etraffic prepare called without valid config. cannot get 'etraffic' config from: {}".
                         format(config))
    supplier_id = etraffic_cfg['buses_id']
    stations_dict = {}

    # Keep only titles that occur exactly once in the rasp data; a title seen
    # twice (even inside one station's list) is dropped as ambiguous.
    title_counts = Counter(title for station in stations for title in station['titles'])
    titles_list = [title for title, count in title_counts.items() if count == 1]

    # Sentinel substituted for NULL extra_info so that rows without a value can
    # be compared with ``==`` in the joins and filters below.
    no_okato = 'empty_okato'
    with session_scope() as session:
        okato_extra_info = func.coalesce(PointMatching.extra_info, no_okato).label('extra_info')

        # Distinct (title, extra_info) pairs among the supplier's station points.
        base_titles = session.query(PointMatching.title, okato_extra_info).filter(
            PointMatching.supplier_id == supplier_id,
            PointMatching.type == PointType.STATION,
            PointMatching.title.in_(titles_list)
        ).group_by(PointMatching.title, PointMatching.extra_info).subquery()

        # Number of point rows per (title, extra_info) pair.
        stations_ = session.query(PointMatching.title,
                                  okato_extra_info,
                                  func.count(PointMatching.id).label('stations_num')).filter(
            PointMatching.supplier_id == supplier_id,
            PointMatching.type == PointType.STATION,
            PointMatching.title.in_(titles_list)
        ).group_by(PointMatching.title, PointMatching.extra_info).subquery()

        # Distinct (title, extra_info, region) triples; feeds the two region
        # aggregates below.
        regions_group = session.query(PointMatching.title, okato_extra_info, PointMatching.region).filter(
            PointMatching.supplier_id == supplier_id,
            PointMatching.type == PointType.STATION,
            PointMatching.title.in_(titles_list)
        ).group_by(PointMatching.title, PointMatching.extra_info, PointMatching.region).subquery()

        # regions_num: how many distinct region values a pair has. ``count()``
        # counts rows, so a NULL region is counted as one value too.
        regions = session.query(regions_group.c.title,
                                func.coalesce(regions_group.c.extra_info, no_okato).label('extra_info'),
                                func.count().label('regions_num')).group_by(
            regions_group.c.title, regions_group.c.extra_info
        ).subquery()

        # hasregion: how many distinct non-NULL regions a pair has
        # (``count(column)`` skips NULLs).
        notempty = session.query(regions_group.c.title,
                                 func.coalesce(regions_group.c.extra_info, no_okato).label('extra_info'),
                                 func.count(regions_group.c.region).label('hasregion')).group_by(
            regions_group.c.title, regions_group.c.extra_info
        ).subquery()

        # Titles that appear with exactly one distinct extra_info value.
        almost_valid_titles = session.query(base_titles.c.title).group_by(base_titles.c.title). \
            having(func.count() == 1).subquery()

        # One row per (title, extra_info) pair with all three aggregates attached.
        groups = session.query(base_titles.c.title, base_titles.c.extra_info, stations_.c.stations_num,
                               regions.c.regions_num, notempty.c.hasregion). \
            outerjoin(stations_, and_(base_titles.c.title == stations_.c.title,
                                      base_titles.c.extra_info == stations_.c.extra_info)). \
            outerjoin(regions, and_(base_titles.c.title == regions.c.title,
                                    base_titles.c.extra_info == regions.c.extra_info)). \
            outerjoin(notempty, and_(base_titles.c.title == notempty.c.title,
                                     base_titles.c.extra_info == notempty.c.extra_info)).subquery()

        # A title is accepted when it has a single extra_info value and either
        #   * that value is not NULL, or
        #   * it is NULL and the points share exactly one non-NULL region, or
        #   * it is NULL, there is exactly one point, and it has no region.
        valid_title_rows = session.query(groups.c.title).filter(and_(groups.c.title.in_(almost_valid_titles),
                                                                     or_(groups.c.extra_info != no_okato,
                                                                         and_(groups.c.extra_info == no_okato,
                                                                              or_(
                                                                                  and_(groups.c.regions_num == 1,
                                                                                       groups.c.hasregion > 0),
                                                                                  and_(groups.c.stations_num == 1,
                                                                                       groups.c.hasregion == 0)
                                                                              ))))).all()

    log.info("found etraffic titles: %d", len(valid_title_rows))
    # Each result row is a 1-tuple holding the title.
    valid_titles = {title for (title,) in valid_title_rows}
    for station in stations:
        for title in station['titles']:
            if title in valid_titles:
                stations_dict[title] = 's{}'.format(station['station_id'])

    return stations_dict


def match_by_code(stations_dict, point):
    """Look up a point by its supplier-side point id.

    :param stations_dict: prepared mapping of station code to point key.
    :param point: object exposing a ``supplier_point_id`` attribute.
    :return: the mapped point key, or ``None`` when the id is not in the mapping.
    """
    return stations_dict.get(point.supplier_point_id)


def match_by_title(stations_dict, point):
    """Look up a point by its title.

    :param stations_dict: prepared mapping of station title to point key.
    :param point: object exposing a ``title`` attribute.
    :return: the mapped point key, or ``None`` when the title is not in the mapping.
    """
    return stations_dict.get(point.title)


# Functions that build a supplier's lookup dict, keyed by the supplier config's
# 'match_by' value. Each is called as ``preparer(stations, config)``.
PREPARERS = {
    'code': prepare_by_code,
    'etraffic_title': prepare_by_title_etraffic,
}

# Functions that resolve a point against a prepared lookup dict, keyed by the
# same 'match_by' values. Each is called as ``matcher(stations_dict, point)``
# and returns the matched point key or None.
MATCHERS = {
    'code': match_by_code,
    'etraffic_title': match_by_title,
}


class MatchByRaspData(BaseMatcher):
    """Matcher scenario that maps supplier station points using the rasp export.

    On construction the rasp station-codes JSON is downloaded from
    ``RASP_DATA_URL`` and, for every supplier present in both the export and the
    scenario config, a lookup dict is prepared with the function registered in
    ``PREPARERS`` under the supplier's ``'match_by'`` rule. ``_run`` then
    resolves individual points with the matching ``MATCHERS`` function.

    The scenario config (returned by ``self.get_config``) is a dict keyed by
    supplier code; each value must provide ``'rasp_id'`` and ``'match_by'`` and
    may provide ``'prefix'``.
    """
    name = 'rasp_data'
    point_type_policy = TypePolicy.TYPE_STATION
    RASP_DATA_URL = 'https://s3.mds.yandex.net/rasp-bucket/rasp-export/bus_station_codes.json'
    # Seconds to wait for the export server to connect / send data. Without a
    # timeout ``requests`` may block forever on an unresponsive server.
    RASP_DATA_TIMEOUT = 60

    # Parsed JSON export; filled in ``__init__``.
    rasp_data = None
    # Not referenced anywhere in this class; kept for backward compatibility.
    supplier_matchings = None
    # Scenario config with ``'buses_id'`` added to every supplier entry.
    config = None

    def __init__(self, **params):
        """Download the rasp export and prepare per-supplier lookup dicts.

        :param params: forwarded to the base class constructor and to
            ``self.get_config``.
        :raises requests.RequestException: when the download times out, fails,
            or the server answers with an HTTP error status.
        :raises ValueError: when a supplier config names an unknown
            ``'match_by'`` rule (see ``_prepare_stations``).
        """
        super(MatchByRaspData, self).__init__(**params)

        log.info('prepare scenario: %s', self.name)
        config = self.get_config(params)

        response = requests.get(self.RASP_DATA_URL, timeout=self.RASP_DATA_TIMEOUT)
        # Fail with a clear HTTP error instead of trying to parse an error page as JSON.
        response.raise_for_status()
        log.info('got rasp stations JSON with size: %d', len(response.content))
        self.rasp_data = response.json()

        # rasp supplier id -> supplier code used as the config key.
        rasp_to_buses_ids = {cfg['rasp_id']: code for code, cfg in config.items()}

        # Enrich every supplier entry in place with its buses-side supplier id.
        for supplier, cfg in config.items():
            cfg['buses_id'] = self.get_supplier_id(supplier)
        # Must be set before ``_prepare_stations`` runs: preparers receive ``self.config``.
        self.config = config

        # buses supplier id -> {'stations': lookup dict, 'match_by': rule name}
        self.supplier_stations = {}
        for supplier in self.rasp_data['suppliers']:
            rasp_id = int(supplier['id'])
            if rasp_id not in rasp_to_buses_ids:
                # Supplier exists in the export but is not configured for this scenario.
                continue

            supplier_code = rasp_to_buses_ids[rasp_id]
            supplier_cfg = config[supplier_code]
            supplier_id = supplier_cfg['buses_id']
            stations = supplier['station_codes']
            log.info('found "%s" by id %d with %d rasp stations', supplier_code, rasp_id, len(stations))

            prepared_stations = self._prepare_stations(stations, supplier_cfg)
            self.supplier_stations[supplier_id] = {
                'stations': prepared_stations,
                'match_by': supplier_cfg['match_by']
            }

    def _run(self, point):
        """Try to resolve ``point`` against its supplier's prepared lookup.

        :param point: object with ``supplier_id`` plus whatever attribute the
            supplier's matcher reads (``supplier_point_id`` or ``title``).
        :return: ``(True, point_key)`` on a match, ``(False, None)`` otherwise.
        """
        # TODO report if station titles group was used twice and more, also usual title matching
        if point.supplier_id in self.supplier_stations:
            supplier_data = self.supplier_stations[point.supplier_id]
            matcher = MATCHERS[supplier_data['match_by']]
            point_key = matcher(supplier_data['stations'], point)
            if point_key:
                return True, point_key
        return False, None

    def _prepare_stations(self, stations, cfg):
        """Build the lookup dict for one supplier.

        :param stations: the supplier's ``'station_codes'`` list from the export.
        :param cfg: that supplier's config entry; ``'match_by'`` selects the
            preparer, optional ``'prefix'`` is prepended to every lookup key.
        :return: dict mapping lookup key to point key.
        :raises ValueError: if ``cfg['match_by']`` is not registered in ``PREPARERS``.
        """
        match_by = cfg['match_by']
        if match_by not in PREPARERS:
            raise ValueError('unknown matching rule: {}'.format(match_by))
        # Preparers get the whole scenario config, not just this supplier's entry.
        stations_dict = PREPARERS[match_by](stations, self.config)
        prefix = cfg.get('prefix')
        if prefix:
            stations_dict = {'{}{}'.format(prefix, key): value for key, value in stations_dict.items()}
        return stations_dict
