# -*- coding: utf-8 -*-
from __future__ import absolute_import

from collections import defaultdict
from django.db import connection
from typing import Dict, Tuple  # noqa


class ReviewStatRepository(object):
    """In-memory cache of visible review counts per airline and per flight.

    The cache is empty until :meth:`pre_cache` has been called. Before that,
    :meth:`get` returns zeros. Afterwards :meth:`get` answers purely from
    memory without touching the database.
    """

    def __init__(self, conn):
        # type: (connection) -> None
        self._connection = conn
        # airline_id -> number of visible reviews of that airline
        self._count_review_for_airline = {}  # type: Dict[int, int]
        # (airline_id, last whitespace-separated token of the flight number)
        #   -> number of review/flight-number links for that key
        self._count_review_for_flight_number = {}  # type: Dict[Tuple[int, str], int]

    def pre_cache(self):
        # type: () -> None
        """(Re)load all review counters from the database.

        Both counters are built in local dicts and only assigned to ``self``
        once loading has finished. If a query raises, the previous cache
        stays intact, and readers never see a half-built cache.
        """
        airline_counts = {}  # type: Dict[int, int]
        for row in self._load_airline_stat():
            airline_counts[int(row['airline_id'])] = row['amount']

        flight_counts = {}  # type: Dict[Tuple[int, str], int]
        for row in self._load_flight_stat():
            raw_number = row['flight_number']
            if not raw_number:
                continue
            # split() (not split(' ')) so trailing/double spaces don't
            # produce an empty-string key; skip whitespace-only values.
            parts = raw_number.split()
            if not parts:
                continue
            key = (int(row['airline_id']), parts[-1])
            # Different stored strings may share the same last token, so
            # accumulate rather than letting the last row overwrite earlier ones.
            flight_counts[key] = flight_counts.get(key, 0) + row['amount']

        self._count_review_for_airline = airline_counts
        self._count_review_for_flight_number = flight_counts

    def _fetch_dicts(self, sql):
        # type: (str) -> list
        """Execute ``sql`` and return every result row as a column-name dict."""
        with self._connection.cursor() as cursor:
            cursor.execute(sql)
            columns = [col[0] for col in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def _load_airline_stat(self):
        # type: () -> list
        """Return ``{'airline_id', 'amount'}`` rows: visible reviews per airline.

        This is a separate query from the per-flight one. Joining the
        flight-number M2M table would repeat a review once per linked flight
        number, and the airline total would then count that review several
        times.
        """
        return self._fetch_dicts("""
            SELECT
              avia_flightreview.airline_id,
              count(avia_flightreview.id) as amount
            FROM
              avia_flightreview
            WHERE
              avia_flightreview.airline_id is not null
              and avia_flightreview.enable_show=1
            GROUP BY
              avia_flightreview.airline_id
        """)

    def _load_flight_stat(self):
        # type: () -> list
        """Return ``{'airline_id', 'flight_number', 'amount'}`` rows.

        ``amount`` is the number of visible reviews linked to that exact
        stored flight-number string for the airline.
        """
        return self._fetch_dicts("""
            SELECT
              avia_flightreview.airline_id,
              avia_flightnumber.flight_number,
              count(avia_flightreview.id) as amount
            FROM
              avia_flightreview
            JOIN
              avia_flightreview_flight_numbers ON avia_flightreview.id = avia_flightreview_flight_numbers.flightreview_id
            JOIN
              avia_flightnumber ON avia_flightreview_flight_numbers.flightnumber_id = avia_flightnumber.id
            WHERE
              avia_flightreview.airline_id is not null
              and avia_flightreview.enable_show=1
            GROUP BY
              avia_flightreview.airline_id,
              avia_flightnumber.flight_number
        """)

    def get(self, airline_id, flight_number):
        # type: (int, str) -> dict
        """Return cached counters for an airline and a flight number.

        :param airline_id: airline primary key.
        :param flight_number: the token used as cache key, i.e. the last
            whitespace-separated part of a stored flight number.
        :returns: ``{'total': <reviews for that flight>,
            'airline_total': <reviews for the airline>}``. Unknown keys
            yield 0.
        """
        return {
            'total': self._count_review_for_flight_number.get(
                (airline_id, flight_number), 0
            ),
            'airline_total': self._count_review_for_airline.get(airline_id, 0),
        }


# Shared module-level instance bound to Django's default database connection.
# Its counters stay empty (get() returns zeros) until pre_cache() is called.
review_stat_repository = ReviewStatRepository(connection)
