import logging
import re
from functools import partial

import psycopg2
import psycopg2.extras

from travel.rasp.bus.toloka_matching.src import toloka
from travel.rasp.bus.toloka_matching.src.config import Config
from travel.rasp.bus.toloka_matching.src.utils.date import parse_utc_datetime
from travel.rasp.bus.toloka_matching.src.utils.requests import AuthorizationSession
from travel.rasp.bus.toloka_matching.src.utils.sql import insert_many


# Abbreviation expansions applied to point titles and descriptions.
# Each entry is a callable ``fix(text) -> text`` (a bound ``Pattern.sub``).
#
# An abbreviation is replaced only when it stands alone, i.e. it is not
# directly preceded or followed by a word character. Zero-width lookarounds
# are used instead of capturing the neighbouring separators: a captured
# separator is consumed by the match, so in a string such as "АВ АВ" the
# single space shared by both tokens would be eaten by the first match and
# the second abbreviation would be left unexpanded.
DESCRIPTION_FIXES = [
    partial(
        re.compile(r'(?<!\w){}(?!\w)'.format(re.escape(string))).sub,
        repl,
    )
    for string, repl in [('отв', 'поворот'), ('АВ', 'автовокзал')]
]


def _process_points(points):
    for p in points:
        for fix in DESCRIPTION_FIXES:
            for k in ('title', 'description'):
                p[k] = fix(p[k])


def _fetch_unmatched_points_for_supplier(cursor, supplier, limit):
    logging.info("Fetch poinst for supplier: %s", supplier)
    cursor.execute(
        """
        SELECT
            *
        FROM (
            SELECT
                distinct on (ep.description)
                ep.supplier as supplier,
                ep.supplier_id as supplier_id,
                ep.title as title,
                ep.type as type,
                ep.description as description,
                ep.latitude as latitude,
                ep.longitude as longitude
            FROM matching.endpoints as ep
            WHERE (ep.supplier, ep.supplier_id) NOT IN (
                SELECT supplier, supplier_id FROM matching.matching
                    UNION
                SELECT supplier, supplier_id FROM matching.toloka_tasks
            ) {}
            ORDER BY ep.description, ep.supplier, ep.supplier_id
        ) as s
        ORDER BY abs(char_length(description) - 90)
        LIMIT %(limit)s
        """.format(' AND ep.supplier = %(supplier)s' if supplier is not None else ''),
        {'supplier': supplier, 'limit': limit}
    )
    return cursor.fetchall()


def _fetch_unmatched_points(conn, suppliers, limit):
    """Collect up to ``limit`` unmatched points and apply the text fixes.

    :param conn: open DB connection; a ``RealDictCursor`` is created on it so
        that rows come back as dict-like objects.
    :param suppliers: iterable of suppliers to query one after another, or
        ``None`` to run a single query across all suppliers.
    :param limit: maximum total number of points to return.
    :return: list of point rows with ``title``/``description`` already passed
        through ``_process_points``.
    """
    with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        points = []
        if suppliers is None:
            points += _fetch_unmatched_points_for_supplier(cur, None, limit)
        else:
            remaining = limit
            for supplier in suppliers:
                batch = _fetch_unmatched_points_for_supplier(cur, supplier, remaining)
                points += batch
                # Charge only the rows fetched in this iteration against the
                # budget; subtracting the cumulative total would count rows
                # from earlier suppliers again on every pass.
                remaining -= len(batch)
                if remaining <= 0:
                    # Budget exhausted: stop querying, but still fall through
                    # so the collected points receive the text fixes.
                    break

        _process_points(points)

        return points


def _generate_task_data(points):
    task_data = [
        {k: x[k] for k in ('title', 'description', 'latitude', 'longitude', 'type')}
        for x in points
    ]
    return task_data


def _upload_tasks_to_toloka(conn, toloka_client, pool_id, task_data, points):
    """Create Toloka tasks in a pool and record them in ``matching.toloka_tasks``.

    :param conn: open DB connection used for the bookkeeping insert.
    :param toloka_client: client exposing ``create_pool_tasks``.
    :param pool_id: pool the tasks are created in; also stored with every row.
    :param task_data: task inputs handed to the Toloka client.
    :param points: source points; each created task is paired with the point
        at the same position.
    """
    # NOTE(review): the keyword is spelled ``inputss`` - presumably this is
    # the client's actual parameter name; confirm against toloka.Client.
    created_tasks = toloka_client.create_pool_tasks(
        pool_id=pool_id, open_pool=False, inputss=task_data,
    )

    with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        rows = []
        # Tasks and points are matched purely by position, and zip() stops at
        # the shorter sequence.
        # NOTE(review): assumes the client returns tasks in input order - confirm.
        for task, point in zip(created_tasks, points):
            rows.append({
                'id': task['id'],
                'created': parse_utc_datetime(task['created']),
                'supplier': point['supplier'],
                'supplier_id': point['supplier_id'],
                'pool_id': pool_id,
            })
        insert_many(cur, 'matching.toloka_tasks', rows)


def create(pool_id, suppliers, limit):
    """Fetch unmatched points and upload them to a Toloka pool as tasks.

    :param pool_id: Toloka pool that receives the tasks.
    :param suppliers: iterable of suppliers to take points from, or ``None``
        to take points from all suppliers.
    :param limit: maximum number of points to upload.
    """
    toloka_client = toloka.Client(AuthorizationSession(
        uri=Config.Toloka.URI,
        auth=Config.Toloka.AUTH,
    ))

    db_connection = psycopg2.connect(Config.CONNECTION_STRING)
    try:
        # Used as a context manager, a psycopg2 connection only wraps a
        # transaction (commit on success, rollback on error); it does not
        # close the connection, hence the explicit close() below.
        with db_connection:
            points = _fetch_unmatched_points(db_connection, suppliers, limit)
            logging.info("points %s", len(points))
            if not points:
                # Nothing to send: skip the empty Toloka request and the
                # empty bookkeeping insert.
                logging.info("No unmatched points, nothing to upload")
                return
            tasks_data = _generate_task_data(points)
            _upload_tasks_to_toloka(db_connection, toloka_client, pool_id, tasks_data, points)
            logging.info("Done")
    finally:
        db_connection.close()
