import logging
from typing import Tuple
import boto3
import gzip
from io import BytesIO
import json
import requests

from nile.api.v1 import clusters
from nile.api.v1.record import Record
from yql.api.v1.client import YqlClient
from yt.wrapper import YtClient
from library.python.vault_client.instances import Production as VaultClient


logger = logging.getLogger(__name__)


def get_secret_uuid(secret_name: str, secret_key: str) -> Tuple[str, str]:
    """
    Look up a secret in vault by name and read one key from its current version.

    Args:
        secret_name: alias for secret in vault (passed as the ``query`` of ``list_secrets``)
        secret_key: key inside the secret version whose value is returned

    Returns:
        (uuid, token). Both are empty strings when the lookup does not yield exactly
        one secret. ``token`` is empty when the key is missing from the version.
        All failures are logged rather than raised.
    """
    logger.info(f"Retrieving UUID for secret {secret_name}")
    vault = VaultClient(decode_files=True)
    found = vault.list_secrets(query=secret_name)
    found_count = len(found)

    # Refuse to guess: anything other than a single hit is reported and yields empty values.
    if found_count != 1:
        if found_count == 0:
            logger.error(f"Failed to find any secret with name {secret_name}")
        else:
            logger.error(f"Found {found_count} secrets with name {secret_name}, can't pick one")
        return "", ""

    secret = found[0]
    uuid = secret["uuid"]
    logger.info(f"Using secret {secret_name} with uuid {uuid}, latest version {secret['last_secret_version']['version']}")
    token = vault.get_version(uuid)["value"].get(secret_key, "")
    if not token:
        logger.error(f"Failed to get token by key {secret_key} from {uuid}")
    return uuid, token


def get_yql_client(token: str) -> YqlClient:
    """
    Create a YQL client authenticated with the given token.

    Args:
        token: YQL token from vault

    Returns:
        YqlClient: a freshly constructed client
    """
    client = YqlClient(token=token)
    return client


def get_yt_client(proxy: str, token: str) -> YtClient:
    """
    Create a YT client bound to one cluster.

    Args:
        proxy: YT cluster name
        token: YT token from vault

    Returns:
        YtClient: a freshly constructed client
    """
    client = YtClient(proxy=proxy, token=token)
    return client


def execute_query(yql_client: YqlClient, query_file: str, templating: dict):
    """
    Render a YQL query template from a file and run it synchronously.

    The file content is filled in with ``str.format(**templating)``. Literal
    braces in the query therefore have to be doubled (``{{`` / ``}}``).

    Args:
        yql_client: YqlClient object
        query_file: path to a UTF-8 file containing the query template
        templating: key-value pairs to fill in placeholders in query file

    Raises:
        RuntimeError: if the query results are not reported as successful
    """
    with open(query_file, encoding='utf-8') as template_file:
        template = template_file.read()
    rendered_sql = template.format(**templating)

    request = yql_client.query(rendered_sql, syntax_version=1)
    request.run()
    outcome = request.get_results()
    if outcome.is_success:
        return
    raise RuntimeError('YQL error while executing the query')


def download_knowledges_snapshot_s3(s3_bucket: str, s3_path: str):
    """
    Fetch a gzip-compressed JSON snapshot from S3 and parse it.

    Args:
        s3_bucket: name of a bucket in s3
        s3_path: object key of the knowledges file (gzip-compressed, .gz extension)

    Returns:
        the parsed JSON document (expected to be the list of knowledges)
    """
    s3_client = boto3.client('s3')
    response = s3_client.get_object(Bucket=s3_bucket, Key=s3_path)
    compressed = response['Body'].read()

    # Decompress and decode in memory; json.load reads the whole text stream.
    with gzip.open(BytesIO(compressed), mode='rt', encoding='utf-8') as text_stream:
        return json.load(text_stream)


def download_knowledges_snapshot_http(url: str, timeout=None):
    """
    Download a gzip-compressed JSON snapshot over HTTP(S) and parse it.

    Args:
        url: link to snapshot file with knowledges
        timeout: optional ``requests`` timeout (seconds, or a ``(connect, read)``
            tuple). ``None`` (the default) waits indefinitely, as before.

    Returns:
        the parsed JSON document (expected to be the list of knowledges)

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status
    """
    # The whole body is needed in memory anyway, so streaming buys nothing.
    # The context manager releases the connection back to the pool.
    with requests.get(url, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of trying to gunzip an error page.
        response.raise_for_status()
        payload = response.content

    with gzip.open(BytesIO(payload), 'rt', encoding='utf-8') as f:
        return json.load(f)


def levenshtein_distance(s1, s2):
    """
    Levenshtein (edit) distance calculation.

    Returns the minimum number of single-element insertions, deletions and
    substitutions needed to turn ``s1`` into ``s2``. Any sized, iterable,
    element-comparable sequences work (strings, lists, ...).
    """
    # https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#Python
    # Walk the longer sequence in the outer loop so each DP row is sized by the shorter one.
    if len(s1) < len(s2):
        s1, s2 = s2, s1

    if len(s2) == 0:
        return len(s1)

    prev = range(len(s2) + 1)
    for row, a in enumerate(s1, start=1):
        curr = [row]
        for col, b in enumerate(s2, start=1):
            curr.append(min(
                prev[col] + 1,             # insertion
                curr[col - 1] + 1,         # deletion
                prev[col - 1] + (a != b),  # substitution (free when elements match)
            ))
        prev = curr

    return prev[-1]


def print_matching_stats(matching, threshold=0.6):
    """
    Print summary counts and shares for a list of matching results.

    Args:
        matching: sized sequence of tuples. ``x[0]`` is a distance compared
            against ``threshold``; ``x[1]`` and ``x[2]`` are compared for
            equality (presumably the historical and the matched target; TODO confirm).
        threshold: a sample with distance strictly below this value counts as matched

    An empty ``matching`` prints zero counts instead of raising ZeroDivisionError.
    """
    total = len(matching)

    def _report(label, selected):
        # Guard against empty input; format as a percentage with two decimals
        # (avoids float artifacts like 57.99999999999999%).
        share = len(selected) / total if total else 0.0
        print(f'{label}: {len(selected)}, {share:.2%}')

    matched_total = [x for x in matching if x[0] < threshold]
    matched_exactly = [x for x in matching if x[0] == 0 and x[1] == x[2]]
    matched_with_level_change = [x for x in matching if x[0] == 0 and x[1] != x[2]]
    # Strict ``<`` (previously ``<=``): a sample exactly at ``threshold`` is "not matched",
    # consistent with ``matched_total`` and ``not_matched``, so it is no longer counted twice.
    matched_approximately = [x for x in matching if x[0] < threshold and x[1] != x[2]]
    not_matched = [x for x in matching if x[0] >= threshold]

    print('Matching statistics')
    print(f'Total # of samples historical: {total}')
    _report('Matched total', matched_total)
    _report('Matched exactly', matched_exactly)
    _report('Matched with url level change', matched_with_level_change)
    _report('Matched approximately', matched_approximately)
    _report('Not matched', not_matched)


def create_target(categories, max_level=None):
    """
    Join category names into a slash-separated target path.

    Args:
        categories: sliceable sequence of category strings
        max_level: if truthy, keep only the first ``max_level`` categories.
            ``None`` or ``0`` keeps all of them.

    Returns:
        str: e.g. ``"a/b/c"``
    """
    kept = categories[:max_level] if max_level else categories
    return '/'.join(kept)


def get_number_of_topics(datapath, cluster='hahn'):
    """
    Count distinct values of the ``topic`` column in a YT table.

    Args:
        datapath: path to the YT table
        cluster: YT cluster name (defaults to ``"hahn"``, the previously hard-coded value)

    Returns:
        number of unique values in the ``topic`` column (``nunique()`` of the dataframe)

    Raises:
        ValueError: if the table has no ``topic`` column
    """
    records = clusters.YT(cluster).read(datapath)
    df = records.as_dataframe()
    # Explicit check instead of ``assert``, so the validation is not stripped under ``python -O``.
    if 'topic' not in df.columns:
        raise ValueError(f"Table {datapath} has no 'topic' column")
    return df['topic'].nunique()


def download_file(filepath, client):
    """
    Read the whole content of a file through ``client.read_file``.

    Args:
        filepath: path of the file to read
        client: object with a ``read_file(path)`` method returning a readable
            stream (presumably a ``YtClient``; TODO confirm)

    Returns:
        whatever the stream's ``read()`` returns (raw bytes for binary streams)
    """
    # Named ``content`` rather than ``bytes`` to avoid shadowing the builtin.
    content = client.read_file(filepath).read()
    return content


def load_data(data_path, cluster='hahn'):
    """
    Read a YT table and return it as a dataframe.

    Args:
        data_path: path to the YT table
        cluster: YT cluster name

    Returns:
        result of ``as_dataframe()`` on the read records (presumably a pandas DataFrame)
    """
    yt_cluster = clusters.YT(cluster)
    # 'strict' bytes decoding: presumably raises on undecodable byte values; TODO confirm.
    table = yt_cluster.read(data_path, bytes_decode_mode='strict')
    return table.as_dataframe()


def save_data(df, data_path, cluster='hahn'):
    """
    Write the rows of a dataframe to a YT table.

    Args:
        df: dataframe whose rows are converted with ``to_dict(orient='records')``.
            Column names become ``Record`` keyword arguments.
        data_path: path to the destination YT table
        cluster: YT cluster name
    """
    yt_cluster = clusters.YT(cluster)
    rows = df.to_dict(orient='records')
    yt_cluster.write(data_path, [Record(**row) for row in rows])
