import time
import pprint
from multiprocessing import Pool
from datetime import datetime

from absl import app, flags, logging
from absl.flags import FLAGS
import psycopg2
from psycopg2.extras import execute_values
from tqdm.auto import tqdm
import googleapiclient.discovery
import googleapiclient.errors

from mgst_data.config import DB_HOST, DB_USER, DB_PASSWORD, YOUTUBE_API_KEY
from mgst_data.utils import upsert_data, date_validator, query

# Command-line flags. Both date flags are checked with date_validator()
# (from mgst_data.utils) and rejected with 'invalid date' if it fails.
flags.DEFINE_integer(name='n_channels', default=10, help='top n channels')
flags.DEFINE_string(name='start_date', default=None,
                    help='Starting Date (2019-01-01)')
flags.DEFINE_string(name='end_date', default=None,
                    help='Ending Date (2019-02-01)')

for _date_flag in ('start_date', 'end_date'):
    # A fresh checker is built per flag, as before.
    flags.register_validator(_date_flag, date_validator(),
                             message='invalid date')
del _date_flag

# YouTube Data API v3 client used by query_one(); built once at import time.
youtube = googleapiclient.discovery.build(
    serviceName="youtube",
    version="v3",
    developerKey=YOUTUBE_API_KEY,
    cache_discovery=False,
)


def query_one(channel, debug=False):
    """Look up YouTube snippet metadata for a single channel path.

    Args:
        channel: A channel path of the form ``/channel/<channel id>`` or
            ``/user/<username>``. Any other form is not looked up.
        debug: If true, pretty-print the raw API item that was used.

    Returns:
        A dict with ``channel_id`` (the ``channel`` argument, unchanged),
        ``channel_name``, ``country`` and ``language`` (taken from the
        snippet's ``title``, ``country`` and ``defaultLanguage``; each may be
        None), or None when the path is not recognised, has an empty key, or
        the API returns no matching channel.
    """
    # '/channel/<id>' -> ['', 'channel', '<id>', ...]; the key is element 2.
    parts = channel.split('/')
    if channel.startswith('/channel/'):
        lookup_param = 'id'
    elif channel.startswith('/user/'):
        lookup_param = 'forUsername'
    else:
        return None

    key = parts[2]
    if not key:
        # e.g. '/channel/' with nothing after the prefix: nothing to query.
        return None

    request = youtube.channels().list(part="snippet", **{lookup_param: key})
    response = request.execute()
    # The API can omit 'items' entirely when nothing matches, so indexing it
    # would raise KeyError for a merely-unknown channel.
    items = response.get('items') or []
    if not items:
        return None

    item = items[0]

    if debug:
        pprint.pprint(item)

    snippet = item['snippet']
    return {
        "channel_id": channel,
        "channel_name": snippet.get('title'),
        "country": snippet.get('country'),
        "language": snippet.get('defaultLanguage')
    }


def safe_query_one(channel):
    """Run query_one() on the first column of a row without ever raising.

    ``channel`` is a result row from the top-channels query in main(); its
    first element is the channel path. Any exception (including one from
    indexing the row) is logged with its traceback and None is returned, so
    a single failing lookup does not abort the whole ``pool.imap`` in main().
    """
    try:
        channel_path = channel[0]
        return query_one(channel_path)
    except Exception:
        logging.exception("query failed: %s", channel)
    return None


# Number of worker processes used for the YouTube API lookups in main().
POOL_SIZE = 10


def main(_argv):
    """Fill in missing 'source' language/country metadata for top channels.

    Ranks YouTube channels by ``sum(ccv)`` between ``--start_date`` and
    ``--end_date``. Only channels that have neither a language nor a country
    row of type 'source' or 'moneyball' are ranked. The top ``--n_channels``
    are looked up with safe_query_one() in a process pool. One language row
    and one country row (type 'source') are then upserted per channel that
    was found. Rows are written even when the API reports no language or
    country; in that case the value is None.

    Args:
        _argv: Remaining command-line arguments from ``app.run``; unused.
    """
    start_time = time.time()
    logging.info("Processing top %s channels from %s to %s",
                 FLAGS.n_channels, FLAGS.start_date, FLAGS.end_date)

    con = psycopg2.connect(dbname='product', host=DB_HOST,
                           port='5439', user=DB_USER, password=DB_PASSWORD)
    try:
        channels = query(con, """
            SELECT t1.channel_id,
                ROW_NUMBER() OVER (
                                    ORDER BY sum(ccv) DESC) AS hw_rank
            FROM tahoe.marionette_channel_prod t1
            LEFT JOIN mgst.marionette_metadata_language tl
                ON (t1.channel_id = tl.channel_id
                    AND t1.platform = tl.platform
                    AND tl.type IN ('source', 'moneyball') )
            LEFT JOIN mgst.marionette_metadata_country tc
                ON (t1.channel_id = tc.channel_id
                    AND t1.platform = tc.platform
                    AND tc.type IN ('source', 'moneyball') )
            WHERE (date BETWEEN %s AND %s)
            AND t1.platform = 'youtube'
            AND (tl.channel_id IS NULL
                AND tc.channel_id IS NULL)
            GROUP BY t1.channel_id
            ORDER BY hw_rank
            LIMIT %s
            """, [FLAGS.start_date, FLAGS.end_date, FLAGS.n_channels], True)

        # The pool is created here, not at import time. Under the "spawn" and
        # "forkserver" start methods every worker re-imports this module, so
        # a module-level Pool(...) would try to start processes from a worker
        # that is still bootstrapping. The ``with`` block also shuts the
        # workers down; the results are fully consumed by list() before it
        # exits.
        with Pool(POOL_SIZE) as pool:
            data = list(tqdm(pool.imap(safe_query_one, channels),
                             total=len(channels)))

        total = 0
        country = 0
        language = 0

        l_inserts = []
        c_inserts = []
        for result in tqdm(data):
            if not result:
                # Lookup failed (already logged by safe_query_one) or the
                # channel was not found / not a recognised path.
                continue
            try:
                l_inserts.append([
                    result['channel_id'],
                    'youtube',
                    result['language'],
                    'source'
                ])

                c_inserts.append([
                    result['channel_id'],
                    'youtube',
                    result['country'],
                    'source'
                ])

                total += 1
                if result['country']:
                    country += 1
                if result['language']:
                    language += 1
            except Exception:
                logging.exception('channel %s failed', result)

        upsert_data(con,
                    'mgst.marionette_metadata_language',
                    ['channel_id', 'platform', 'language', 'type'],
                    ['channel_id', 'platform', 'type'],
                    l_inserts)

        upsert_data(con,
                    'mgst.marionette_metadata_country',
                    ['channel_id', 'platform', 'country', 'type'],
                    ['channel_id', 'platform', 'type'],
                    c_inserts)
    finally:
        # NOTE(review): close() does not commit. As before, this relies on
        # upsert_data committing its own writes (not visible here; confirm).
        con.close()

    logging.info("Updated %d Users with %d Country and %d language",
                 total, country, language)
    logging.info("Done %ds", time.time() - start_time)


if __name__ == '__main__':
    # absl's app.run parses the command-line flags, then calls main() with
    # the remaining argv.
    app.run(main=main)
