import logging
from datetime import datetime, timedelta

import yenv
import yt.wrapper as yt
from flask import current_app as app
from flask_script import Command, Option
from nile.api.v1 import filters
from qb2.api.v1 import filters as qf

from jafar import advisor_mongo, shared_cache, clickhouse
from jafar.utils import date_range
from jafar.utils.io import get_cluster
from jafar.utils.iter import take
from jafar_yt.usage_stats import usage_stats_mapper_ch, EVENT_NAMES

logger = logging.getLogger(__name__)

# Name format of the 30-minute ("online") Metrika tables listed under
# YT_METRIKA_PATH_30_MIN; the last processed table name kept in the shared
# cache is parsed with the same format.
ONLINE_TABLE_NAME_FORMAT = '%Y-%m-%dT%H:%M:%S'
# Name format used to build the per-day ("offline") Metrika table names
# that are read from YT_METRIKA_PATH_1_DAY.
OFFLINE_TABLE_NAME_FORMAT = '%Y-%m-%d'

def cleanup_usage_stats(start_date=None, end_date=None):
    """
    Drops ClickHouse partitions of the usage tables that fall into a date range.

    An active partition is dropped when its ``max_date`` is not earlier than
    `start_date` and its ``min_date`` is earlier than `end_date`.
    Passing None for one of the bounds leaves that side of the range open.
    The same range is applied, in order, to the usage logs, hourly, weekly
    and counters tables named in the application config.

    :param start_date: lower bound; interpolated into the query via
        ``str.format`` (callers in this module pass a date object or
        a '%Y-%m-%d' string), or None for no lower bound
    :param end_date: upper bound in the same form, or None for no upper bound
    :raises AssertionError: if both bounds are empty
    """
    # Raised explicitly rather than via `assert`: assert statements are removed
    # under `python -O`, and without any bound the query below would select
    # (and drop) every active partition. The exception type is kept as-is
    # for backward compatibility.
    if not (start_date or end_date):
        raise AssertionError("Either start_date or end_date must be non-null")

    database = app.config['CLICKHOUSE_DATABASE']

    def drop_partitions(table):
        conditions = [
            "database = '{db}'".format(db=database),
            # table name for materialized view partitions starts with ".inner."
            "(table = '{table}' OR table = '.inner.{table}')".format(table=table),
            "active"
        ]
        if start_date:
            conditions.append("max_date >= '{max_date}'".format(max_date=start_date))
        if end_date:
            conditions.append("min_date < '{min_date}'".format(min_date=end_date))

        partitions = clickhouse.execute("""
            SELECT partition
            FROM system.parts
            WHERE
                {conditions}
            GROUP BY partition
        """.format(conditions=" AND ".join(conditions)))
        for partition_name, in partitions:
            # The statement adds its own quotes around the partition value,
            # so any quotes already present in the reported name are stripped.
            clickhouse.execute("ALTER TABLE {db}.{table} DROP PARTITION '{partition}'".format(
                db=database,
                table=table,
                partition=partition_name.strip("'")
            ))
            # Use the module logger (not the root `logging` module) with lazy
            # arguments, like the rest of this file.
            logger.info('Deleted partition: %s.%s', table, partition_name)

    table_config_keys = (
        'CLICKHOUSE_USAGE_LOGS_TABLE',
        'CLICKHOUSE_USAGE_HOURLY_TABLE',
        'CLICKHOUSE_USAGE_WEEKLY_TABLE',
        'CLICKHOUSE_USAGE_COUNTERS_TABLE',
    )
    for config_key in table_config_keys:
        drop_partitions(app.config[config_key])


def get_online_metrika_tables():
    """
    Lists the online Metrika tables that have not been processed yet.

    The name of the last processed table is read from the shared cache
    (key 'USAGE_STATS_LAST_TABLE'). Every table name found under
    YT_METRIKA_PATH_30_MIN is parsed as a timestamp; names strictly later
    than the cached one are returned. When nothing is cached, all listed
    tables are returned.

    :return: list of table names, in the order `yt.list` yielded them
    """
    cached_name = shared_cache.get('USAGE_STATS_LAST_TABLE')
    last_table_time = None
    if cached_name:
        last_table_time = datetime.strptime(cached_name, ONLINE_TABLE_NAME_FORMAT)
        logger.info("Found previously processed online table: %s", cached_name)

    def is_unprocessed(name):
        # The name is always parsed, even without a cached timestamp,
        # so a malformed table name fails in either case.
        table_time = datetime.strptime(name, ONLINE_TABLE_NAME_FORMAT)
        return last_table_time is None or table_time > last_table_time

    # collect all tables later than `last_table_time`
    return [
        name
        for name in yt.list(app.config['YT_METRIKA_PATH_30_MIN'])
        if is_unprocessed(name)
    ]


def get_offline_metrika_tables(start_date, end_date):
    """
    Builds the names of the daily Metrika tables for a date interval.

    :param start_date: first argument passed to `date_range`
    :param end_date: second argument passed to `date_range`
    :return: list with one OFFLINE_TABLE_NAME_FORMAT-formatted name per date
        yielded by `date_range(start_date, end_date)`
    """
    return [
        day.strftime(OFFLINE_TABLE_NAME_FORMAT)
        for day in date_range(start_date, end_date)
    ]


def get_users():
    """
    Collects the hex ids of all documents in the advisor `profile` collection.

    :return: set of `_id.get_hex()` values, one per profile
    """
    user_ids = set()
    for profile in advisor_mongo.db.profile.find():
        user_ids.add(profile['_id'].get_hex())
    return user_ids


def get_usage_stats(path, table_names):
    """
    Runs a cluster job that extracts usage events from Metrika tables.

    Rows are kept when their APIKey equals YT_LAUNCHER_API_KEY, their
    EventName is one of EVENT_NAMES and EventValue is defined. Outside of
    the production environment only rows whose DeviceID is returned by
    `get_users()` are kept. The remaining rows are passed through
    `usage_stats_mapper_ch` and written to a temporary YT table.

    :param path: YT directory that contains the source tables
    :param table_names: iterable of table names (strings) inside `path`
    :return: result of `yt.read_table` over the temporary table, JSON format
    """
    logger.info('Loading usage statistics from the following tables: %s', table_names)

    # NOTE(review): the value is returned from inside the TempTable context,
    # so the context exits before the caller consumes the result. Confirm
    # that `yt.read_table` output stays readable after that point.
    with yt.TempTable(path=app.config['YT_PATH_TMP']) as result_table:
        job = get_cluster().job()

        # "{a,b,c}" selects several tables of one directory in a single path
        table_set = '{' + ','.join(table_names) + '}'
        source_path = '{path}/{tables}'.format(path=path, tables=table_set)

        events = job.table(source_path)
        events = events.filter(
            filters.equals('APIKey', app.config['YT_LAUNCHER_API_KEY']),
            qf.one_of('EventName', EVENT_NAMES),
            qf.defined('EventValue'),
        )
        events = events.project(
            'DeviceID',
            'EventName',
            'EventValue',
            'Latitude',
            'Longitude',
            'ReceiveDate',
            'EventDateTime'
        )

        if yenv.type != 'production':
            logger.info('Leaving usage stats for test devices only')
            known_devices = get_users()
            events = events.filter(qf.one_of('DeviceID', known_devices))

        # column types of the table produced by the mapper
        output_schema = {
            'receive_date': str,
            'user': str,
            'item': str,
            'class_name': str,
            'event_name': str,
            'event_value': str,
            'event_datetime': str,
            'latitude': float,
            'longitude': float
        }

        mapped = events.map(usage_stats_mapper_ch)
        mapped.put(result_table, schema=output_schema)
        job.run()

        logger.info("Done collecting online usage stats. Reading...")
        return yt.read_table(result_table, format='json')


def ensure_ch_table():
    """
    Creates the ClickHouse usage tables if they do not exist yet.

    Four statements are executed in order, each with IF NOT EXISTS and
    ON CLUSTER: the raw usage logs table, then three materialized views
    over it (hourly launch counts, weekly launch counts and per-item
    event counters). Database, cluster and table names come from the
    application config.
    """
    def run_ddl(template, **table_config_keys):
        # Config values are looked up right before each statement, so a
        # missing key only affects the statement that needs it.
        names = {
            'db': app.config['CLICKHOUSE_DATABASE'],
            'cluster': app.config['CLICKHOUSE_CLUSTER'],
        }
        for placeholder, config_key in table_config_keys.items():
            names[placeholder] = app.config[config_key]
        clickhouse.execute(template.format(**names))

    # Raw event log; the views below are populated from it.
    run_ddl("""
        CREATE TABLE IF NOT EXISTS {db}.{usage_logs}
        ON CLUSTER {cluster} (
            date Date,
            user UUID,
            item String,
            class_name String,
            event_name String,
            event_value String,
            event_date DEFAULT toDate(local_datetime),
            local_datetime DateTime,
            latitude Nullable(Float64),
            longitude Nullable(Float64)
        ) ENGINE = ReplicatedMergeTree('/clickhouse/tables/{usage_logs}', '{{replica}}')
        PARTITION BY date ORDER BY user
    """, usage_logs='CLICKHOUSE_USAGE_LOGS_TABLE')

    # app_launch counts per hour of day; place/subplace are pulled out of
    # the 'place' parameter of event_value with regular expressions.
    run_ddl("""
        CREATE MATERIALIZED VIEW IF NOT EXISTS {db}.{usage_hourly}
        ON CLUSTER {cluster}
        ENGINE = ReplicatedSummingMergeTree('/clickhouse/tables/{usage_hourly}', '{{replica}}')
        PARTITION BY date
        ORDER BY (user, place, subplace, date, item, class_name, hour)
        SETTINGS index_granularity=256
        AS SELECT
            user,
            item,
            class_name,
            toHour(local_datetime) as hour,
            count() as app_launch_count,
            extract(visitParamExtractRaw(event_value, 'place'), '{{"(.*?)":{{.*?') as place,
            extract(visitParamExtractRaw(event_value, 'place'), '{{".*?":{{"(.*?)"') as subplace,
            date
        FROM {db}.{usage_logs}
        WHERE event_name = 'app_launch'
        GROUP BY (user, place, subplace, date, item, class_name, hour)
    """, usage_hourly='CLICKHOUSE_USAGE_HOURLY_TABLE', usage_logs='CLICKHOUSE_USAGE_LOGS_TABLE')

    # Same aggregation keyed by day of week (toDayOfWeek result minus one).
    run_ddl("""
        CREATE MATERIALIZED VIEW IF NOT EXISTS {db}.{usage_weekly}
        ON CLUSTER {cluster}
        ENGINE = ReplicatedSummingMergeTree('/clickhouse/tables/{usage_weekly}', '{{replica}}')
        PARTITION BY date
        ORDER BY (user, place, subplace, date, item, class_name, weekday)
        SETTINGS index_granularity=256
        AS SELECT
            user,
            item,
            class_name,
            toDayOfWeek(local_datetime) - 1 as weekday,
            count() as app_launch_count,
            extract(visitParamExtractRaw(event_value, 'place'), '{{"(.*?)":{{.*?') as place,
            extract(visitParamExtractRaw(event_value, 'place'), '{{".*?":{{"(.*?)"') as subplace,
            date
        FROM {db}.{usage_logs}
        WHERE event_name = 'app_launch'
        GROUP BY (user, place, subplace, date, item, class_name, weekday)
    """, usage_weekly='CLICKHOUSE_USAGE_WEEKLY_TABLE', usage_logs='CLICKHOUSE_USAGE_LOGS_TABLE')

    # Per user/item/date counters of rec_view, App_install and app_launch events.
    run_ddl("""
        CREATE MATERIALIZED VIEW IF NOT EXISTS {db}.{usage_counters}
        ON CLUSTER {cluster}
        ENGINE = ReplicatedSummingMergeTree('/clickhouse/tables/{usage_counters}', '{{replica}}')
        PARTITION BY date
        ORDER BY (user, item, date)
        SETTINGS index_granularity=64
        AS SELECT
            user,
            item,
            countIf(event_name='rec_view') as rec_view_count,
            countIf(event_name='App_install') as app_install_count,
            countIf(event_name='app_launch') as app_launch_count,
            date
        FROM {db}.{usage_logs}
        GROUP BY (user, item, date)
    """, usage_counters='CLICKHOUSE_USAGE_COUNTERS_TABLE', usage_logs='CLICKHOUSE_USAGE_LOGS_TABLE')


def save_usage_stats(stats, batch_size=100000):
    """
    Inserts usage records into the ClickHouse usage logs table in batches.

    Batches are obtained with `take(stats, batch_size)` until it returns
    an empty result. Each record (a dict) is modified in place before the
    insert: 'receive_date' is replaced by a `date` under the key 'date',
    'event_datetime' by a `datetime` under 'local_datetime', and
    'event_date' is set to the date part of 'local_datetime'.

    :param stats: source of record dicts consumed through `take`
    :param batch_size: maximum number of records per INSERT
    """
    logger.info('Start updating usage logs clickhouse table')

    def to_row(record):
        # Convert the string timestamps into date/datetime objects
        record['date'] = datetime.strptime(record.pop('receive_date'), '%Y-%m-%d').date()
        record['local_datetime'] = datetime.strptime(record.pop('event_datetime'), '%Y-%m-%d %H:%M:%S')
        record['event_date'] = record['local_datetime'].date()
        return record

    written = 0
    while True:
        batch = take(stats, batch_size)
        if not batch:
            break

        rows = [to_row(record) for record in batch]
        insert_query = "INSERT INTO {db}.{usage_logs} VALUES ".format(
            db=app.config['CLICKHOUSE_DATABASE'],
            usage_logs=app.config['CLICKHOUSE_USAGE_LOGS_TABLE']
        )
        clickhouse.execute(insert_query, rows, types_check=False)

        written += len(batch)
        logger.info("%d records written", written)

    logger.info('Usage logs inserted: %d new records', written)


class UpdateUsageStats(Command):
    """
    Management command that loads usage statistics from Metrika tables
    in YT into ClickHouse.

    The mode is chosen in `run` from the command line options:

    * ``--online_table NAME`` processes a single 30-minute table;
    * ``--start_date`` together with ``--end_date`` ('%Y-%m-%d') reloads
      the daily tables of that interval;
    * with no options, all not yet processed 30-minute tables are loaded.
    """
    option_list = (
        Option('--start_date', dest='start_date', default=None, action='store'),
        Option('--end_date', dest='end_date', default=None, action='store'),
        Option('--online_table', dest='online_table', default=None, action='store'),
    )

    @staticmethod
    def process_online_stats(tables):
        """
        Loads the given 30-minute tables and stores their records.

        :param tables: names of tables under YT_METRIKA_PATH_30_MIN
        """
        stats = get_usage_stats(app.config['YT_METRIKA_PATH_30_MIN'], tables)
        save_usage_stats(stats)

    @staticmethod
    def update_online_stats():
        """
        Downloads fresh data from metrika 30min (aka "online") tables.
        This is the default mode which should run periodically.

        Before loading, partitions older than USAGE_STATS_SAVE_HORIZON days
        are dropped. After loading, the name of the latest processed table
        is stored in the shared cache so that the next run skips it.
        """
        tables = get_online_metrika_tables()
        if not tables:
            logger.info('Nothing to load')
            return

        horizon = timedelta(days=app.config['USAGE_STATS_SAVE_HORIZON'])
        oldest_kept_day = (datetime.today() - horizon).strftime('%Y-%m-%d')
        cleanup_usage_stats(end_date=oldest_kept_day)

        UpdateUsageStats.process_online_stats(tables)

        # save last processed table
        last_table = max(tables, key=lambda name: datetime.strptime(name, ONLINE_TABLE_NAME_FORMAT))
        # NOTE(review): timeout=-1 presumably means "never expire" - confirm
        # against the shared_cache implementation.
        shared_cache.set('USAGE_STATS_LAST_TABLE', str(last_table), timeout=-1)
        logger.info("Updated last processed online usage stats table: %s", last_table)

    @staticmethod
    def update_offline_stats(start_date, end_date):
        """
        Reloads usage statistics from the daily ("offline") tables.

        Partitions matching the interval are dropped first, then the daily
        tables of the interval are loaded and stored.

        :param start_date: start of the interval, passed to `date_range`
        :param end_date: end of the interval, passed to `date_range`
        """
        tables = get_offline_metrika_tables(start_date, end_date)

        if not tables:
            logger.info('Nothing to load')
            return

        # FIXME the date in a metrika table name != the date of the events it holds
        cleanup_usage_stats(start_date, end_date)
        stats = get_usage_stats(app.config['YT_METRIKA_PATH_1_DAY'], tables)
        save_usage_stats(stats)

    def run(self, start_date, end_date, online_table):
        """
        Entry point of the command: configures YT, makes sure the
        ClickHouse tables exist and dispatches to one of the update modes.

        :param start_date: value of ``--start_date`` ('%Y-%m-%d') or None
        :param end_date: value of ``--end_date`` ('%Y-%m-%d') or None
        :param online_table: value of ``--online_table`` or None
        :raises AssertionError: if only one of the two dates is given
        :raises ValueError: if a date does not match '%Y-%m-%d'
        """
        yt.update_config(app.config['YT_CONFIG'])
        ensure_ch_table()

        if online_table:
            UpdateUsageStats.process_online_stats([online_table])
        elif start_date or end_date:
            # Raised explicitly rather than via `assert`, which is removed
            # under `python -O`; the exception type is kept unchanged.
            if not (start_date and end_date):
                raise AssertionError(
                    "Both start_date and end_date are required for offline update")
            # Both values are known to be non-empty here, so no further checks.
            start_date = datetime.strptime(start_date, '%Y-%m-%d').date()
            end_date = datetime.strptime(end_date, '%Y-%m-%d').date()
            logger.info("Performing offline update since %s till %s", start_date, end_date)
            self.update_offline_stats(start_date, end_date)
        else:
            logger.info("Performing online update")
            self.update_online_stats()
