from typing import List, Dict, Tuple, AnyStr
from multiprocessing import Pool
from functools import partial
import logging
import os
from io import BytesIO
from tqdm import tqdm
import openpyxl
from dataclasses import dataclass
import yt.wrapper
from urllib.parse import urlparse


from customer_service.ml.lib.data.utils import (
    download_knowledges_snapshot_http,
    levenshtein_distance,
    print_matching_stats,
    download_file,
    get_yt_client
)


logger = logging.getLogger(__name__)


@dataclass
class RemapperItem:
    """Replacement values that a knowledge url is remapped to.

    Instances are stored as values of the ``KnowledgesRemapper.remapper``
    dict, keyed by the source url.
    """
    # Url written into the url column of a remapped row.
    url: str
    # Label read from the xlsx sheet next to the url; it is stored here but
    # not read anywhere in the visible code.
    label: int
    # Knowledge text written into the text column of a remapped row.
    text: str


def _remap(remapper, col_url, col_text, row):
    if (row[col_url] in remapper and remapper[row[col_url]]):
        row[col_url] = remapper[row[col_url]].url
        row[col_text] = remapper[row[col_url]].text

    yield row


class KnowledgesRemapper:
    """
    Builds a ``url -> RemapperItem`` mapping from a colour-coded xlsx sheet
    and applies it to a YT table.

    Each sheet row describes a pair of knowledges (text, label and url for
    both). A url cell whose fill colour index equals ``color_code_remove`` is
    treated as marked for removal.
    """

    def __init__(self, proxy='hahn', token=None, num_col_map=None):
        """
        Args:
            proxy: YT proxy passed to ``get_yt_client``
            token: YT token; when falsy the ``YT_TOKEN`` environment variable
                is used instead
            num_col_map: dict mapping a 0-based sheet column position to its
                logical name; when falsy the default layout below is used
        """
        self.client = get_yt_client(proxy, token or os.getenv("YT_TOKEN"))
        self.num_col_map = num_col_map
        self.remapper = {}

        if not self.num_col_map:
            self.num_col_map = {
                0: 'know1',
                1: 'know2',
                2: 'label1',
                3: 'label2',
                4: 'tf-idf_similarity',
                5: 'url1',
                6: 'url2'
            }

        # Reverse lookup: logical column name -> position within a sheet row.
        self.col_num_map = {v: k for k, v in self.num_col_map.items()}

    def create_remapper_from_xlsx(self, xlsx_yt_path, color_code_remove='FFFF0000'):
        """
        Fill ``self.remapper`` from the active sheet of an xlsx file
        Args:
            xlsx_yt_path: path passed to ``download_file`` to fetch the xlsx
            color_code_remove: fill colour index that marks a url cell as
                removed

        For every sheet row:
            * neither url marked - each url maps to its own item;
            * both urls marked - both urls map to None;
            * exactly one url marked - both urls map to the item of the
              unmarked url, unless a url already has an entry, which is kept.
        """
        xlsx_file = download_file(xlsx_yt_path, self.client)
        wb = openpyxl.load_workbook(BytesIO(xlsx_file))
        sh = wb.active
        # NOTE(review): iteration starts at the first row, so a header row
        # (if the sheet has one) is recorded as a mapping as well.
        for row in sh.iter_rows():
            url1 = row[self.col_num_map['url1']].value
            url2 = row[self.col_num_map['url2']].value
            color1 = row[self.col_num_map['url1']].fill.start_color.index
            color2 = row[self.col_num_map['url2']].fill.start_color.index
            label1 = row[self.col_num_map['label1']].value
            label2 = row[self.col_num_map['label2']].value
            know1 = row[self.col_num_map['know1']].value
            know2 = row[self.col_num_map['know2']].value

            remove1 = color1 == color_code_remove
            remove2 = color2 == color_code_remove

            if not remove1 and not remove2:
                self.remapper[url1] = RemapperItem(url1, label1, know1)
                self.remapper[url2] = RemapperItem(url2, label2, know2)

            elif remove1 and remove2:
                self.remapper[url1] = None
                self.remapper[url2] = None

            else:
                # Exactly one side is marked: the unmarked side survives.
                correct_val = RemapperItem(
                    url2 if remove1 else url1,
                    label2 if remove1 else label1,
                    know2 if remove1 else know1
                )
                # Entries written by earlier rows (including None) win.
                self.remapper.setdefault(url1, correct_val)
                self.remapper.setdefault(url2, correct_val)

    def _refresh_texts(self, rows, col_url, col_text):
        """
        Copy knowledge texts from rows into the remapper items they describe
        Args:
            rows: iterable of mappings holding at least col_url and col_text
            col_url: name of the column with the knowledge url
            col_text: name of the column with the knowledge text

        An item is refreshed only from rows whose url equals the item's OWN
        url. Keying on the remapper key instead would overwrite the text of a
        surviving knowledge with the text of the removed knowledge that is
        remapped onto it.
        """
        items_by_url = {}
        for item in self.remapper.values():
            if item:
                items_by_url.setdefault(item.url, []).append(item)

        for row in rows:
            for item in items_by_url.get(row[col_url], ()):
                item.text = row[col_text]

    def update_remapper_from_yt_table(self, source_table, col_url='KnowledgeId', col_text='Knowledge'):
        """
        Refresh the texts of remapper items from a YT table
        Args:
            source_table: table read with ``self.client.read_table``
            col_url: name of the column with the knowledge url
            col_text: name of the column with the knowledge text
        """
        logger.info("Updating remapper")
        self._refresh_texts(tqdm(self.client.read_table(source_table)), col_url, col_text)

    def remap_knowledges(self, source_table, destination_table,
                         col_url='KnowledgeId', col_text='Knowledge'):
        """
        Remap urls and texts of source_table into destination_table
        Args:
            source_table: input table; also used to refresh remapper texts
            destination_table: output table of the map operation
            col_url: name of the column with the knowledge url
            col_text: name of the column with the knowledge text

        Raises:
            ValueError: if the remapper is empty
        """
        if not self.remapper:
            raise ValueError("Remapper is not created")

        self.update_remapper_from_yt_table(source_table, col_url, col_text)

        remap = partial(_remap, self.remapper, col_url, col_text)
        self.client.run_map(remap, source_table, destination_table)


def parse_snapshot(snapshot: List[dict], product: str) -> List[dict]:
    """
    Select the knowledges of one product from a snapshot
    Args:
        snapshot: list of knowledge dicts, each with a 'supertag' path
        product: product tag expected as the fourth '/'-separated component
            of 'supertag'
    Returns:
        the snapshot entries that belong to the product, in snapshot order

    Every snapshot entry (selected or not) gets an 'n_levels' key holding the
    number of '/'-separated components of its 'supertag'.
    """
    selected = []
    for entry in tqdm(snapshot):
        assert 'supertag' in entry
        path_parts = entry['supertag'].split('/')
        depth = len(path_parts)
        entry['n_levels'] = depth
        if depth > 3 and path_parts[3] == product:
            selected.append(entry)
    return selected


def construct_knowledges_graph(knowledges: List[dict]) -> Dict[str, List[str]]:
    """
    Build a parent -> children adjacency dict from knowledge entries
    Args:
        knowledges: list of dicts with 'supertag' (a '/'-separated path) and
            'n_levels' keys

    Returns:
        dict with one key per supertag; the value lists the supertags whose
        path without the last component equals the key
    """
    for entry in knowledges:
        assert 'supertag' in entry
        assert 'n_levels' in entry

    # Shallower paths come first, so a parent is registered before its children.
    ordered = sorted(knowledges, key=lambda k: (k['n_levels'], k['supertag']))

    graph = {}
    for entry in ordered:
        tag = entry['supertag']
        graph.setdefault(tag, [])

        # Everything before the last '/' ('' when there is no '/').
        parent_tag, _, _ = tag.rpartition('/')
        siblings = graph.get(parent_tag)
        if siblings is not None:
            siblings.append(tag)

    return graph


def prepare_actual_knowledges_graph(url: str, product: str) -> Dict[str, List[str]]:
    """
    Download a knowledges snapshot and build the graph for one product
    Args:
        url: link to the snapshot with knowledges
        product: product tag used to filter the snapshot

    Returns:
        graph of the product's knowledges (supertag -> child supertags)
    """
    product_knowledges = parse_snapshot(download_knowledges_snapshot_http(url), product)
    return construct_knowledges_graph(product_knowledges)


def convert_knowledge_to_categories(knowledge: str, start_category: str) -> List[str]:
    """
    Split a knowledge url into categories, dropping everything up to and
    including the first occurrence of start_category
    Args:
        knowledge: url of knowledge
        start_category: name of the category after which the result starts

    Returns:
        non-empty url components following start_category, or all non-empty
        components when start_category does not occur
    """
    parts = [part for part in knowledge.split('/') if part]
    try:
        cut = parts.index(start_category)
    except ValueError:
        # start_category is absent: keep the whole path.
        return parts
    return parts[cut + 1:]


def get_knowledges_from_graph(graph: Dict[str, List[str]], category: str) -> List[List[str]]:
    """
    Convert every node of the graph to a list of categories for matching
    Args:
        graph: graph of knowledges; only its keys are used
        category: category name passed to convert_knowledge_to_categories
            as the starting category

    Returns:
        one list of categories per graph key, in key order
    """
    return [
        convert_knowledge_to_categories(node, category)
        for node in tqdm(graph.keys())
    ]


def prepare_actual_knowledges(url: str, product: str) -> List[List]:
    """
    Build the list of actual knowledges of a product from a snapshot url
    Args:
        url: link to the snapshot with knowledges
        product: product tag; used both to filter the snapshot and as the
            starting category of every returned knowledge

    Returns:
        list of knowledges, each one a list of categories
    """
    graph = prepare_actual_knowledges_graph(url, product)
    return get_knowledges_from_graph(graph, product)


def parse_historical_knowledge(knowledge: str) -> List[str]:
    """
    Split a knowledge url taken from a ticket into categories
    Args:
        knowledge: knowledge url

    Returns:
        non-empty '/'-separated components of the url, without a trailing
        'index.html' component
    """
    parts = list(filter(None, knowledge.split('/')))

    if parts and parts[-1] == 'index.html':
        del parts[-1]

    return parts


def get_knowledges_history(df):
    """
    Parse the 'knowledgeId' column into lists of categories
    Args:
        df: container with a 'knowledgeId' column of knowledge urls

    Returns:
        one list of categories per value of the column, in column order
    """
    assert 'knowledgeId' in df

    return [
        parse_historical_knowledge(knowledge_id)
        for knowledge_id in tqdm(df['knowledgeId'])
    ]


def preprocess_data(df):
    """
    Decode the byte-string columns of the dataframe as UTF-8, in place
    Args:
        df: dataframe with 'knowledgeId', 'knowledgeTitle', 'messageTime'
            and 'text' columns

    Returns:
        the same dataframe object with the four columns decoded
    """
    for name in ('knowledgeId', 'knowledgeTitle', 'messageTime', 'text'):
        decoded = df[name].str.decode("utf-8")
        df[name] = decoded
    return df


def knowledges_distance(knowledge_historical: List[str], knowledge_actual: List[str]):
    """
    Weighted distance between two knowledges given as lists of categories
    Args:
        knowledge_historical: categories of the historical knowledge; a '#'
            entry is a placeholder that contributes nothing to the distance
        knowledge_actual: categories of the actual knowledge

    Returns:
        999 when the lists differ in length; otherwise the sum of the
        per-position Levenshtein distances, the i-th position weighted
        by 1 / 2 ** i
    """
    if len(knowledge_historical) != len(knowledge_actual):
        return 999

    return sum(
        levenshtein_distance(historical, actual) * (1 / 2 ** depth)
        for depth, (historical, actual) in enumerate(zip(knowledge_historical, knowledge_actual))
        if historical != '#'
    )


def find_best_match(knowledges_actual: List[List], knowledge: List) -> Tuple[float, List[AnyStr], List[AnyStr]]:
    """
    Find the actual knowledge closest to the given knowledge
    Args:
        knowledges_actual: list of actual knowledges (lists of categories)
        knowledge: knowledge to match (list of categories)

    Returns:
        tuple: best distance, best matching actual knowledge, and the form
        of the input knowledge that produced the match (the knowledge itself
        or one of its shifted variants). When nothing was scored the result
        is (1e9, [], knowledge).
    """
    def shifted_variants():
        # Built lazily, so a variant is only constructed right before it is
        # scored. '#' is the placeholder knowledges_distance skips.
        yield knowledge
        yield knowledge[1:] + ['#']  # first category dropped
        yield ['#'] + knowledge      # one leading category missing

    best_distance, best_match, knowledge_matched = 1e9, [], knowledge

    for candidate in knowledges_actual:
        if not (candidate and knowledge):
            continue

        # Equal last categories count as an exact match and end the search.
        if candidate[-1] == knowledge[-1]:
            return 0, candidate, knowledge

        for variant in shifted_variants():
            score = knowledges_distance(variant, candidate)
            if score < best_distance:
                best_distance, best_match, knowledge_matched = score, candidate, variant
                if best_distance == 0:
                    # Nothing can beat a zero distance.
                    return best_distance, best_match, knowledge_matched

    return best_distance, best_match, knowledge_matched


def map_knowledges(knowledge_id: str, actual_knowledges: List[List]) -> Tuple:
    """
    Match a single knowledge url against the actual knowledges
    Args:
        knowledge_id: knowledge url to match
        actual_knowledges: list of actual knowledges (lists of categories)

    Returns:
        best_distance, best_match, knowledge_matched
    """
    return find_best_match(actual_knowledges, parse_historical_knowledge(knowledge_id))


def match_knowledges(knowledges, knowledges_actual):
    """
    Match every knowledge against the actual knowledges in parallel
    Args:
        knowledges: list of knowledges to match (lists of categories)
        knowledges_actual: list of actual knowledges (lists of categories)

    Returns:
        list with the find_best_match result for every input knowledge,
        in input order
    """
    find_best_match_with_actual = partial(find_best_match, knowledges_actual)

    # The context manager shuts the worker processes down; the previous
    # version never closed the pool. Pool.map blocks until every result is
    # available, so leaving the block right after it is safe.
    with Pool() as pool:
        return pool.map(find_best_match_with_actual, knowledges)


def match_knowledges_to_df(df, actual_knowledges):
    """
    Match the knowledges of a dataframe and store the results in it
    Args:
        df: dataframe accepted by preprocess_data and get_knowledges_history
        actual_knowledges: list of actual knowledges (lists of categories)

    Returns:
        the dataframe with two added columns: 'distance' (best distance) and
        'matched_categories' (best matching actual knowledge)
    """
    df = preprocess_data(df)
    history = get_knowledges_history(df)
    matches = match_knowledges(history, actual_knowledges)
    print_matching_stats(matches)

    distances, categories = [], []
    for match in matches:
        distances.append(match[0])
        categories.append(match[1])

    df['distance'] = distances
    df['matched_categories'] = categories
    return df


@yt.wrapper.yt_dataclass
class CasesDataRaw:
    """Row written by convert_cases_from_xlsx_to_yt; every field is a string."""
    # Stringified value of the first sheet column.
    case: str
    # Stringified value of the second sheet column.
    action: str
    # Stringified value of the third sheet column.
    incoming_number: str
    # Path component of the url found in the fourth sheet column.
    url: str


def convert_cases_from_xlsx_to_yt(xlsx_yt_path, destination_table, proxy='hahn', token=None):
    """
    Convert the active sheet of an xlsx file into a structured YT table
    Args:
        xlsx_yt_path: path passed to ``download_file`` to fetch the xlsx
        destination_table: table the CasesDataRaw rows are written to
        proxy: YT proxy passed to ``get_yt_client``
        token: YT token; when falsy the ``YT_TOKEN`` environment variable
            is used instead

    The first sheet row is skipped. The first three columns are stored
    stringified; from the url in the fourth column only the path is kept.
    """
    client = get_yt_client(proxy, token or os.getenv("YT_TOKEN"))
    xlsx_file = download_file(xlsx_yt_path, client)
    wb = openpyxl.load_workbook(BytesIO(xlsx_file))
    sh = wb.active
    rows = []
    for row in sh.iter_rows(min_row=2):
        # An empty cell has value None, and urlparse(None) returns a bytes
        # result whose path stringifies as "b''". Parse '' instead so an
        # empty url is stored as an empty string.
        path = urlparse(row[3].value or '').path
        rows.append(
            CasesDataRaw(
                str(row[0].value),
                str(row[1].value),
                str(row[2].value),
                str(path)
            )
        )

    client.write_table_structured(
        destination_table,
        CasesDataRaw,
        rows
    )
