from collections import Counter
from contextlib import closing
from datetime import datetime, timezone
from operator import itemgetter

import psycopg2
from psycopg2.extras import RealDictCursor

from travel.rasp.bus.toloka_matching.src.config import Config
from travel.rasp.bus.toloka_matching.src.utils.itertools import group_by
from travel.rasp.bus.toloka_matching.src.utils.sql import insert_many


def _fetch_unhandled_results(cur, overlap):
    """Fetch aggregated Toloka answers for tasks that have no matching yet.

    Only tasks with at least ``overlap`` rows in ``matching.toloka_results``
    and no row in ``matching.toloka_matching`` are returned, joined with
    their task and endpoint.

    :param cur: DB cursor to run the query on; ``apply_result`` passes a
        ``RealDictCursor`` and reads the columns by name.
    :param overlap: minimum number of answers a task must have.
    :return: ``cur.fetchall()`` — rows with ``task_id``, ``supplier``,
        ``supplier_id``, ``type`` and ``matching``. ``matching`` aggregates
        one ``[code, parent_code]`` array per answer. Both entries come from
        LEFT JOINs on ``ims.point``, so either may be NULL: ``code`` when the
        answered code is unknown, ``parent_code`` when the point has no parent.
    """
    # NOT EXISTS instead of NOT IN: NOT IN returns no rows at all as soon as
    # the subquery yields a NULL task_id, and NOT EXISTS plans as an anti-join.
    cur.execute(
        """
        SELECT
            r.task_id,
            ep.supplier,
            ep.supplier_id,
            ep.type,
            r.matching
        FROM (
            SELECT
                r.task_id,
                array_agg(ARRAY[p.code, pp.code]) as matching
            FROM matching.toloka_results as r
            LEFT JOIN ims.point as p ON r.code = p.code
            LEFT JOIN ims.point as pp ON pp.id = p.parent_id
            WHERE NOT EXISTS (
                SELECT 1
                FROM matching.toloka_matching as m
                WHERE m.task_id = r.task_id
            )
            GROUP BY r.task_id
            HAVING count(*) >= %(overlap)s
        ) as r
        JOIN matching.toloka_tasks as t ON t.id = r.task_id
        JOIN matching.endpoints as ep ON (
            (ep.supplier, ep.supplier_id) = (t.supplier, t.supplier_id)
        )
        """,
        {'overlap': overlap},
    )
    return cur.fetchall()


def apply_result(overlap=5, threshold=0.8):
    """Pick the best-voted point for every unhandled Toloka task and store it.

    Tasks are read with ``_fetch_unhandled_results``. These are tasks with at
    least ``overlap`` answers and no row in ``matching.toloka_matching``.

    For each task the identical answers are counted, and the
    ``(code, score)`` pair with the highest vote share wins. The share is
    ``votes / overlap``. If the endpoint type is ``'city'`` and that best
    share is below ``threshold``, the shares are summed per parent point code
    and the best parent wins instead.

    One row per task is inserted into ``matching.toloka_matching``. It holds
    the chosen code, its score and a timestamp shared by all rows of the run.

    :param overlap: minimum number of answers per task; also the denominator
        of the vote share.
    :param threshold: share below which a ``'city'`` task falls back to
        voting by parent point.
    """
    def get_max_matching(matching):
        # Highest-scoring (key, score) pair; (None, 0) when there are no pairs.
        return max(matching, key=itemgetter(1), default=(None, 0))

    def select_best(matching, point_type):
        # `matching` holds ((code, parent_code), share) pairs.
        matched = get_max_matching((code, share) for (code, _), share in matching)
        if matched[1] < threshold and point_type == 'city':
            # Low agreement on a city: regroup the shares by parent code and
            # pick the parent with the largest total share.
            shares_by_parent = group_by(matching, key=lambda pair: pair[0][1], value=itemgetter(1))
            return get_max_matching(
                (parent, sum(shares)) for parent, shares in shares_by_parent.items()
            )
        return matched

    # psycopg2's connection context manager only ends the transaction
    # (commit on success, rollback on error) and never closes the connection,
    # so closing() is needed to release it.
    with closing(psycopg2.connect(Config.CONNECTION_STRING)) as conn:
        with conn, conn.cursor(cursor_factory=RealDictCursor) as cur:
            results = _fetch_unhandled_results(cur, overlap)
            for row in results:
                # Guaranteed by the HAVING clause of the query.
                assert len(row['matching']) >= overlap
                # Each answer is a [code, parent_code] list; tuples make it hashable.
                votes = Counter(map(tuple, row['matching']))
                # NOTE(review): shares are normalised by `overlap`, not by the
                # actual answer count. A task with more than `overlap` answers
                # can therefore score above 1 — confirm this is intended.
                row['matched'] = select_best(
                    [(answer, count / overlap) for answer, count in votes.items()], row['type'],
                )
            dt_now = datetime.now(tz=timezone.utc)
            rows = [
                {
                    'task_id': row['task_id'],
                    'code': row['matched'][0],
                    'score': row['matched'][1],
                    'timestamp': dt_now,
                }
                for row in results
            ]
            # Nothing new to store; skip the INSERT entirely.
            if rows:
                insert_many(cur, 'matching.toloka_matching', rows)
