import functools
import time
import collections
from pprint import pformat

from absl import app, logging
from tqdm.auto import tqdm
import langcodes
import psycopg2
import pycountry
import pandas as pd 

from mgst_data.config import DB_HOST, DB_USER, DB_PASSWORD
from mgst_data.utils import upsert_data, query

# Raw metadata values treated as "no information". The map builders compare
# stripped values against this tuple, and main() binds it to the SQL `IN %s`
# filters so such rows rank last.
UNKNOWNS = (
    'other',
    'none',
    '',
    'Unknown',
    'NULL',
    'UNKNOWN',
)

# Build Language Map


def build_lang_map(con):
    """Map raw language values from the metadata table to pycountry names.

    Reads every distinct, non-NULL ``language`` from
    ``mgst.marionette_metadata_language`` (excluding ``type = 'country'``
    rows). Each value is resolved by trying, in order:

    1. the hand-curated ``trusted_lang_map``;
    2. an ISO 639 alpha-2 code, then an alpha-3 code;
    3. the exact name, then ``pycountry.languages.lookup``;
    4. two ``langcodes``-based guesses.

    The first successful detection wins.

    Args:
        con: Open DB connection, passed through to ``query``.

    Returns:
        dict mapping the raw (unstripped) DB value to a pycountry language
        name. Values whose stripped form is in ``UNKNOWNS`` map to
        ``'Unknown'``. Values that cannot be resolved are logged and omitted.
    """
    langs = query(con, """
        SELECT distinct language FROM mgst.marionette_metadata_language
        WHERE type != 'country'
        AND language IS NOT NULL
    """, fetch=True)

    lang_map = {}

    # Hand-curated overrides: stripped raw value -> ISO 639-3 code. Several
    # keys look truncated (e.g. 'Françai'), presumably matching how the
    # values are stored upstream.
    trusted_lang_map = {
        'Ilocano': 'ilo',
        'العر': 'ara',
        'हि': 'hin',
        'Françai': 'fra',
        'Portugue': 'por',
        'Tiếng': 'vie',
        'Русс': 'rus',
        'Portugu': 'por',
        'Esperant': 'epo',
        'Bahasa I': 'ind',
        'Kinyarwa': 'kin',
        'Afrikaan': 'afr',
        'Lithuani': 'lit',
        'Azerbaij': 'aze',
        'Sinhales': 'sin',
        'ไท': 'tha',
        'Bahasa M': 'msa',
        'Mandarin': 'zho',
        '日本': 'jpn',
        '한국': 'kor',
        'Bokmal': 'nob'
    }

    for lang_og, in langs:
        lang = lang_og.strip()

        if lang in UNKNOWNS:
            # Key on the raw value: main() looks languages up unstripped,
            # the same way the resolved entries below are keyed.
            lang_map[lang_og] = 'Unknown'
            continue

        detects = []
        trusted = trusted_lang_map.get(lang)
        if trusted:
            # Only query with a real code. Depending on the pycountry version,
            # get() with a None value may raise instead of returning None.
            detects.append(pycountry.languages.get(alpha_3=trusted))
        detects.extend([
            pycountry.languages.get(alpha_2=lang),
            pycountry.languages.get(alpha_3=lang),
            pycountry.languages.get(name=lang),
        ])

        try:
            detects.append(pycountry.languages.lookup(lang))
        except LookupError:
            pass

        # The langcodes guesses are best-effort: any failure just means "no
        # detection" from that source. Catch Exception (not a bare except)
        # so that KeyboardInterrupt/SystemExit still propagate.
        try:
            detects.append(pycountry.languages.get(
                alpha_2=langcodes.Language.get(lang).language))
        except Exception:  # pylint: disable=broad-except
            logging.debug('langcodes.Language.get failed for: %s', lang)

        try:
            detects.append(pycountry.languages.get(
                alpha_2=langcodes.find(lang).language))
        except Exception:  # pylint: disable=broad-except
            logging.debug('langcodes.find failed for: %s', lang)

        # First truthy detection in priority order.
        detect = next((d for d in detects if d), None)

        if detect is None:
            logging.warning('Failed to map lang: %s', lang)
            continue

        lang_map[lang_og] = detect.name

    return lang_map


# Build Country Map
def build_country_map(con):
    """Map raw country values from the metadata table to pycountry entries.

    Reads every distinct, non-NULL ``country`` from
    ``mgst.marionette_metadata_country`` (excluding ``type = 'language'``
    rows). Each value is resolved by trying, in order:

    1. the hand-curated ``trusted_country_map``;
    2. an ISO 3166-1 alpha-2 code, as given and upper-cased;
    3. an alpha-3 code;
    4. the exact name, then ``pycountry.countries.lookup``.

    The first successful detection wins.

    Args:
        con: Open DB connection, passed through to ``query``.

    Returns:
        dict mapping the raw (unstripped) DB value to ``[name, alpha_2]``.
        Values whose stripped form is in ``UNKNOWNS`` map to
        ``['Unknown', '']``. Values that cannot be resolved are logged and
        omitted.
    """
    countries = query(con, """
    SELECT distinct country FROM mgst.marionette_metadata_country
    WHERE type != 'language'
    AND country IS NOT NULL
    """, fetch=True)

    country_map = {}

    # Hand-curated overrides: stripped raw value -> ISO 3166-1 alpha-3 code.
    # Currently empty.
    trusted_country_map = {
    }

    for c_og, in countries:
        c = c_og.strip()

        if c in UNKNOWNS:
            # Key on the raw value: main() looks countries up unstripped,
            # the same way the resolved entries below are keyed.
            country_map[c_og] = ['Unknown', '']
            continue

        detects = []
        trusted = trusted_country_map.get(c)
        if trusted:
            # Only query with a real code. Depending on the pycountry version,
            # get() with a None value may raise instead of returning None.
            detects.append(pycountry.countries.get(alpha_3=trusted))
        detects.extend([
            pycountry.countries.get(alpha_2=c),
            pycountry.countries.get(alpha_2=c.upper()),
            pycountry.countries.get(alpha_3=c),
            pycountry.countries.get(name=c),
        ])

        try:
            detects.append(pycountry.countries.lookup(c))
        except LookupError:
            pass

        # First truthy detection in priority order.
        detect = next((d for d in detects if d), None)

        if detect is None:
            logging.error('Failed to map country: %s', c)
            continue

        country_map[c_og] = [detect.name, detect.alpha_2]

    return country_map


# Selects one row per (channel_id, platform) seen in any of the language,
# country or link metadata tables, joined with its best-ranked language and
# country. Within each table, rows whose value is NULL or in UNKNOWNS rank
# last (9999). Otherwise a lower rank means a more preferred source `type`.
# Both %s placeholders are bound to UNKNOWNS by main().
_CHANNELS_SQL = """
    WITH l_data AS (SELECT
        channel_id,
        platform,
        language,
        type,
        ROW_NUMBER() OVER (PARTITION BY channel_id, platform ORDER BY (CASE
            WHEN (language IS NULL OR language IN %s) THEN 9999
            WHEN type = 'trusted' THEN 1
            WHEN type = 'source' THEN 2
            WHEN type = 'moneyball' THEN 3
            WHEN type = 'comprehend' THEN 4
            WHEN type = 'whatlango' THEN 5
        END)) as rank
    FROM mgst.marionette_metadata_language
    WHERE type in ('trusted', 'source', 'comprehend', 'moneyball', 'whatlango')
    GROUP BY 1, 2, 3, 4),
    c_data AS (SELECT
        channel_id,
        platform,
        country,
        type,
        ROW_NUMBER() OVER (PARTITION BY channel_id, platform ORDER BY (CASE
            WHEN (country IS NULL OR country IN %s) THEN 9999
            WHEN type = 'trusted' THEN 1
            WHEN type = 'source' THEN 2
            WHEN type = 'moneyball' THEN 3
        END)) as rank
    FROM mgst.marionette_metadata_country
    WHERE type in ('trusted', 'source', 'moneyball')
    GROUP BY 1, 2, 3, 4),
    link AS (SELECT
        channel_id,
        platform,
        channel_name,
        link
    FROM mgst.marionette_metadata_link
    ),
    channels AS (SELECT
        channel_id,
        platform
        FROM (
            SELECT channel_id, platform FROM l_data
            UNION
            SELECT channel_id, platform FROM c_data
            UNION
            SELECT channel_id, platform FROM link
        )
        GROUP BY 1, 2
    )
    SELECT
        t1.channel_id,
        t1.platform,
        tc.country,
        tc.type as country_type,
        tl.language,
        tl.type as language_type,
        link.channel_name,
        link.link
    FROM channels t1
    LEFT JOIN l_data tl
        ON (t1.channel_id = tl.channel_id AND t1.platform = tl.platform AND tl.rank = 1)
    LEFT JOIN c_data tc
        ON (t1.channel_id = tc.channel_id AND t1.platform = tc.platform AND tc.rank = 1)
    LEFT JOIN link
        ON (t1.channel_id = link.channel_id AND t1.platform = link.platform)
    """


def main(_argv):
    """Normalize channel language/country metadata into the prod table.

    Builds raw-to-normalized lookup maps for languages and countries and
    loads one row per (channel_id, platform) via ``_CHANNELS_SQL``. It then
    upserts the normalized rows, keyed on (channel_id, platform), into
    ``mgst.marionette_metadata_prod`` in batches.

    Args:
        _argv: Remaining command-line arguments from ``app.run`` (unused).
    """
    start_time = time.time()
    batch_size = 500000

    con = psycopg2.connect(dbname='product', host=DB_HOST,
                           port='5439', user=DB_USER, password=DB_PASSWORD)
    try:
        logging.info("Building Normalization Mapping")

        lang_map = build_lang_map(con)
        country_map = build_country_map(con)

        logging.info("Loaded %d Languages and %d Countries",
                     len(lang_map), len(country_map))

        channels = pd.read_sql(_CHANNELS_SQL, con,
                               params=[UNKNOWNS, UNKNOWNS])

        logging.info("Loaded %d Channels", len(channels))

        def insert(data):
            upsert_data(con,
                        'mgst.marionette_metadata_prod',
                        ['channel_id', 'platform', 'country', 'country_code',
                         'country_type', 'language', 'language_type',
                         'channel_name', 'link'],
                        ['channel_id', 'platform'],
                        data, page_size=len(data))

        # Per source `type`, how many rows got a normalized value.
        n_country = collections.Counter()
        n_language = collections.Counter()

        inserts = []
        rows = channels.itertuples(index=False)
        for (channel_id, platform, country, country_type, language,
             language_type, channel_name, link) in tqdm(rows,
                                                       total=len(channels)):
            country_info = country_map.get(country)  # [name, alpha_2] or None
            lang_name = lang_map.get(language)

            if country_info:
                n_country[country_type] += 1

            if lang_name:
                n_language[language_type] += 1

            inserts.append([
                channel_id,
                platform,
                country_info[0] if country_info else None,
                country_info[1] if country_info else None,
                country_type,
                lang_name,
                language_type,
                channel_name,
                link,
            ])

            if len(inserts) >= batch_size:
                insert(inserts)
                inserts = []

        # Flush the remainder. Skip it when empty: insert() passes
        # page_size=len(data), and a zero page size is presumably not valid
        # for upsert_data (NOTE(review): confirm against its implementation).
        if inserts:
            insert(inserts)

        logging.info("Updated Languages: %s", pformat(n_language))
        logging.info("Updated Countries: %s", pformat(n_country))
    finally:
        # Release the connection even if a query or upsert fails.
        con.close()

    logging.info("Done %ds", time.time() - start_time)


if __name__ == '__main__':
    # Hand control to absl's app runner, which invokes main() with argv.
    app.run(main=main)
