# -*- encoding: utf-8 -*-
import logging
from collections import defaultdict

from contextlib2 import closing
from datetime import timedelta, date

from pathlib2 import Path  # noqa

from travel.avia.dump_data.lib.model_classes.base_model import BaseModel
from travel.avia.dump_data.lib.mysql_connector import MysqlConnector  # noqa
from travel.library.python.dicts import file_util


# Number of consecutive days, starting from RouteCountModel._today, for which
# the schedule is queried.
COUNT_DAYS = 30


# Stations that have a settlement set directly on the station row.
StationQuery = """
    select id, settlement_id
    from www_station
    where settlement_id is not null
"""

# Station -> settlement links stored in a separate table (presumably extra
# links on top of www_station.settlement_id — confirm).
Station2SettlementQuery = """
    select station_id, settlement_id
    from www_station2settlement
"""

# Number of distinct route numbers per (station_from_id, station_to_id) pair
# whose run_mask has '1' at the given position. The %s placeholder is used as
# the position argument of MySQL SUBSTRING(), which is 1-based (position 0
# yields an empty string).
RouteSheduleQuery = """
    select `station_from_id`, `station_to_id`, count(distinct `route_number`) as `total`
    from `www_routeschedule`
    where (substring(run_mask, %s, 1) = "1")
    group by `station_from_id`, `station_to_id`
"""


def collect(cursor, day):
    """Fetch per-station-pair route counts for a single day.

    Executes ``RouteSheduleQuery`` for ``day`` and materializes all rows.

    :param cursor: open DB-API cursor. Callers read rows by column name
        (``row['total']``), so this is presumably a dict cursor — confirm.
    :param day: ``datetime.date`` whose run_mask flag is checked.
    :return: list of rows with ``station_from_id``, ``station_to_id`` and
        ``total`` columns.
    """
    # day_index() is 0-based (Jan 1 -> 0) while MySQL SUBSTRING() positions
    # are 1-based: position 0 returns '' (Jan 1 never matched) and any other
    # position read the previous slot's flag. Shift by one.
    # NOTE(review): assumes run_mask stores the flag for a date at 0-based
    # offset day_index(date) — confirm against the code that writes the mask.
    position = day_index(day) + 1
    cursor.execute(RouteSheduleQuery, (position, ))
    result = list(cursor)
    logging.info('Fetched schedule for %s: %s rows', day, len(result))
    return result


def day_index(date_):
    """Return the 0-based slot of ``date_`` in a 31-slots-per-month layout.

    Every month occupies exactly 31 slots regardless of its real length,
    so Jan 1 -> 0, Feb 1 -> 31 and Dec 31 -> 371.
    """
    month_offset = 31 * (date_.month - 1)
    return month_offset + date_.day - 1


class RouteCountModel(BaseModel):
    """Dump the peak daily number of routes for every pair of points.

    A point is either a station (key ``'s<station_id>'``) or a settlement
    linked to a station (key ``'c<settlement_id>'``). For each of
    ``COUNT_DAYS`` consecutive days starting at ``self._today``, route counts
    are accumulated per (from, to) point pair. One ``proto_model`` record is
    written per pair, carrying the largest single-day total.
    """

    def __init__(self, name, connector, proto_model):
        self._name = name
        self.connector = connector  # type: MysqlConnector
        # Factory for output records; must expose PointFrom, PointTo, Count
        # and SerializeToString() (presumably a protobuf message class).
        self.proto_model = proto_model

        # First day of the scanned window: today + 1 day.
        self._today = date.today() + timedelta(days=1)
        self._reset_state()

    def _reset_state(self):
        # Clear everything accumulated by a previous dump.
        # station id -> set of 'c<settlement_id>' point keys
        self._related_keys = {}
        # (point_from, point_to) -> {date: summed route count}
        self._route_counts = defaultdict(lambda: defaultdict(int))

    @property
    def name(self):
        return self._name

    def dump_into_directory(self, directory):
        # type: (Path) -> None
        """Query MySQL and write the serialized records into ``directory``.

        The output file name comes from ``get_output_file_name()``.
        Accumulated state is reset first, so repeated calls on the same
        instance do not double-count routes.
        """
        file_name = directory / self.get_output_file_name()

        # Both collections are filled incrementally (+=, set.add); without
        # a reset, a second dump would add on top of the first one's totals.
        self._reset_state()

        with closing(self.connector.get_connection()) as connection:
            with closing(connection.cursor()) as cursor:
                self._prepare_related_keys(cursor)
                self._fetch_route_count_from_mysql(cursor)

        with open(str(file_name), 'wb') as outfile:
            self._dump_into_file(outfile)

        logging.info('Write %s', str(file_name))

    def _prepare_related_keys(self, cursor):
        """Load station -> settlement links from both link sources."""
        cursor.execute(StationQuery)
        for row in cursor:
            self._add_related_key(row['id'], row['settlement_id'])

        cursor.execute(Station2SettlementQuery)
        for row in cursor:
            self._add_related_key(row['station_id'], row['settlement_id'])
        logging.info('Collected related keys for %d stations', len(self._related_keys))

    def _add_related_key(self, station, settlement):
        """Register ``settlement`` as a point related to ``station``."""
        settlement_point_key = 'c' + str(settlement)
        self._related_keys.setdefault(station, set()).add(settlement_point_key)

    def _fetch_route_count_from_mysql(self, cursor):
        """Accumulate per-day route totals for every related point pair.

        NOTE(review): totals are summed across station pairs, so at the
        settlement level a route number served by several stations of the
        same settlement is counted once per station pair — confirm intended.
        """
        for delta in range(COUNT_DAYS):
            day = self._today + timedelta(days=delta)

            for aggregate in collect(cursor, day):
                from_keys = self._related_point_keys(aggregate['station_from_id'])
                # Invariant for the inner loop; compute once per row.
                to_keys = self._related_point_keys(aggregate['station_to_id'])
                total = aggregate['total']
                for from_key in from_keys:
                    for to_key in to_keys:
                        self._route_counts[from_key, to_key][day] += total

        logging.info('Got route counts for %d point pairs', len(self._route_counts))

    def _related_point_keys(self, station_id):
        """Return a new set: the station's own key plus its settlement keys."""
        keys = {'s' + str(station_id)}
        keys.update(self._related_keys.get(station_id, ()))
        return keys

    def _dump_into_file(self, file):
        """Write one serialized ``proto_model`` record per point pair.

        ``Count`` is the maximum single-day total within the scanned window.
        """
        # items()/values() work on both Python 2 and 3 (iter* are py2-only).
        for (point_from, point_to), totals in self._route_counts.items():
            proto = self.proto_model()
            proto.PointFrom = point_from
            proto.PointTo = point_to
            proto.Count = max(totals.values())
            file_util.write_binary_string(file, proto.SerializeToString())
