#!/usr/bin/python
# -*-coding: utf-8 -*-
# vim: sw=4 ts=4 expandtab ai

import ipaddress
import json
import logging
import md5
import struct
import time
from collections import defaultdict
from datetime import datetime
from user_agents import parse

from yql.api.v1.client import YqlClient
import ylock
from clickhouse_driver import Client
from yt.wrapper import YtClient, YtResponseError

# TODO solomon

# YQL that prepares one verify log for loading into ClickHouse
# (rendered with str.format in AvpExporter.process_log).
# Format fields:
#   prepared_log      - output table, overwritten (INSERT ... WITH TRUNCATE)
#   verify_log        - source log table
#   cookie_sync_table - table with yuid / ext_id / `timestamp` columns
#   platform_id       - only rows whose JSON `Data` has /platformid equal to it are kept
# Each kept row gets the ext_id with the largest `timestamp` for its Yandexuid;
# the LEFT JOIN keeps rows without a cookie-sync match (their ext_id is NULL).
# Literal YQL braces are doubled ({{ }}) because the text goes through str.format.
PREPARE_LOG_YQL_TEMPLATE = """
$get_property = ($data, $property_name) -> {{ return (Yson::ConvertToString(
    Yson::YPath(
            Yson::ParseJson($data, Yson::Options(false AS Strict)), $property_name, Yson::Options(false AS Strict)
        )
    ))
}};

INSERT INTO `{prepared_log}` WITH TRUNCATE
SELECT
    VerifyLog.ServerTimestamp as ServerTimestamp,
    VerifyLog.Data as Data,
    VerifyLog.Yandexuid as Yandexuid,
    VerifyLog.ClientIP as ClientIP,
    VerifyLog.UserAgent as UserAgent,
    VerifyLog.Referer as Referer,
    CookieSync.ext_id as ext_id
FROM `{verify_log}` AS VerifyLog
LEFT JOIN (
    SELECT yuid, MAX_BY(ext_id, `timestamp`) AS ext_id
    FROM `{cookie_sync_table}`
    GROUP BY yuid
) as CookieSync
ON VerifyLog.Yandexuid == CookieSync.yuid
WHERE $get_property(Data, "/platformid") == "{platform_id}";
"""


def yabs_md5(value):
    """Return a 64-bit integer hash of *value* derived from its MD5 digest.

    The 16-byte MD5 digest is read as four big-endian unsigned 32-bit words
    w0..w3 and folded into ``((w1 ^ w3) << 32) | (w0 ^ w2)``.

    :param value: bytes to hash; text (unicode) input is encoded as UTF-8.
    :returns: an int in ``[0, 2**64)``.
    :raises TypeError: if *value* is neither bytes nor text.
    """
    import hashlib

    # hashlib needs bytes; encode text explicitly rather than relying on an
    # implicit (ASCII-only on Python 2) conversion.
    if isinstance(value, type(u'')):
        value = value.encode('utf-8')

    words = struct.unpack('>4L', hashlib.md5(value).digest())
    hi = words[1] ^ words[3]
    lo = words[0] ^ words[2]
    return (hi << 32) | lo


class Verifier():
    """Load prepared verification-log rows into per-day ClickHouse tables.

    Each row is routed to the table ``<tables_prefix>YYYYMMDD`` derived from
    its ``ServerTimestamp``; tables are created on first use.  Rows are
    buffered per table and written in batches of ``bulk_insert_rows_count``;
    :meth:`flush_buffer` writes whatever is left.  The ``processed_logs``
    table records which source logs have already been loaded, and
    :meth:`cleanup` expires old records there and drops old day tables.
    """

    def __init__(
            self,
            platform_id,
            platform_name,
            ch_host,
            ch_database,
            ch_user,
            ch_password,
            tables_prefix,
            custom_fields,
            uid_salt,
            keep_tables_count,
            keep_processed_logs_records_days,
            bulk_insert_rows_count,
    ):
        self.platform_id = platform_id
        self.platform_name = platform_name
        self.ch_host = ch_host
        self.ch_database = ch_database
        self.ch_user = ch_user
        self.ch_password = ch_password
        self.tables_prefix = tables_prefix
        self.custom_fields = custom_fields
        self.uid_salt = uid_salt
        self.keep_tables_count = keep_tables_count
        self.keep_processed_logs_records_days = keep_processed_logs_records_days

        self.logger = logging.getLogger(platform_name)

        # table name -> DESCRIBE TABLE rows; also serves as "table exists" cache
        self.tables = {}

        # table name -> list of prepared rows awaiting a bulk insert
        self.rows_buffer = defaultdict(list)
        self.bulk_insert_rows_count = bulk_insert_rows_count

        self._ch_connect()

    def _ch_connect(self):
        """Create the ClickHouse client from the stored connection settings."""
        # NOTE(review): verify=False disables TLS certificate verification on
        # this secure connection; confirm that is intended.
        self.ch_client = Client(
            host=self.ch_host,
            database=self.ch_database,
            user=self.ch_user,
            password=self.ch_password,
            secure=True,
            verify=False
        )

    def _create_table_processed_logs(self):
        """Ensure the ``processed_logs`` table exists; failures are only logged."""
        # '{cluster}' and '{replica}' are sent literally for ClickHouse to
        # expand (server-side macros).
        try:
            self.ch_client.execute("""
                CREATE TABLE IF NOT EXISTS processed_logs ON CLUSTER '{cluster}' (
                    Timestamp UInt64,
                    Log String
                )
                ENGINE = ReplicatedMergeTree('/processed_logs', '{replica}')
                PARTITION BY toDate(Timestamp/1000)
                ORDER BY (Timestamp);
                """)
        except Exception as e:
            self.logger.error("_create_table_processed_logs failed: {}".format(e))

    def processed_log(self, log):
        """Return the ``processed_logs`` rows matching *log* (at most one).

        A non-empty result means *log* was already loaded.  The stored value
        is the last path component (see :meth:`update_processed_logs`), so
        callers should pass the bare log name.

        :raises Exception: whatever the ClickHouse client raised.  The error
            is logged and re-raised, because without an answer the caller
            cannot tell whether the log was already loaded.
        """
        self._create_table_processed_logs()

        try:
            # Bind the name as a query parameter rather than splicing it
            # into the SQL text.
            return self.ch_client.execute(
                "SELECT Log from processed_logs WHERE Log = %(log)s LIMIT 1;",
                {"log": log},
            )
        except Exception as e:
            self.logger.error("processed_log failed: {}".format(e))
            raise

    def update_processed_logs(self, log):
        """Record *log* as loaded; failures are only logged.

        Only the last '/'-separated component of *log* is stored, together
        with the current time in milliseconds.
        """
        self._create_table_processed_logs()

        try:
            # Timestamp is a UInt64 column, so send an integer, not a float.
            now_ms = int(time.time() * 1000)
            self.ch_client.execute("INSERT INTO processed_logs VALUES", [[now_ms, log.split('/')[-1]]])
        except Exception as e:
            self.logger.error("update_processed_logs failed: {}".format(e))

    def _create_table_query(self, table_name):
        """Build the CREATE TABLE statement for the day table *table_name*.

        Every custom field becomes a ``String`` column (dots replaced by
        underscores) appended to both the column list and the ORDER BY key.
        With no custom fields the statement has no dangling commas.
        """
        custom_columns = [field.replace('.', '_') for field in self.custom_fields]
        column_defs = [
            "EventTime UInt64",
            "IPVer FixedString(4)",
            "ClientIP String",
            "UserAgent String",
            "Referer String",
            "BrowserName String",
            "BrowserVersion String",
            "OSName String",
            "OSVersion String",
            "Uid String",
        ] + ["{} String".format(column) for column in custom_columns]
        sorting_key = [
            "ClientIP", "UserAgent", "Referer", "BrowserName",
            "BrowserVersion", "OSName", "OSVersion", "Uid",
        ] + custom_columns

        # Doubled braces survive .format() as the literal ClickHouse macros
        # '{cluster}' and '{replica}'.
        return (
            "CREATE TABLE IF NOT EXISTS {table_name} ON CLUSTER '{{cluster}}' (\n"
            "    {columns}\n"
            ")\n"
            "ENGINE = ReplicatedMergeTree('/{table_name}', '{{replica}}')\n"
            "PARTITION BY toDate(EventTime/1000)\n"
            "ORDER BY ({sorting_key});"
        ).format(
            table_name=table_name,
            columns=",\n    ".join(column_defs),
            sorting_key=", ".join(sorting_key),
        )

    def create_table(self, table_name):
        """Create *table_name* if needed and cache its column layout.

        The DESCRIBE TABLE result is stored in ``self.tables`` and later fixes
        the order of inserted values.  If DESCRIBE fails nothing is cached,
        so the next call retries.
        """
        if table_name in self.tables:
            return

        try:
            self.ch_client.execute(self._create_table_query(table_name))
        except Exception as e:
            self.logger.error("create_table failed: {}".format(e))

        # remember table fields order
        query = "DESCRIBE TABLE {}".format(table_name)
        try:
            self.tables[table_name] = self.ch_client.execute(query)
        except Exception as e:
            self.logger.error("describe at create_table failed: {}".format(e))

    def insert_row(self, row, custom_data):
        """Buffer one prepared row; write its table's buffer once it is full."""
        table, prepared_row = self.prepare_row(row, custom_data)
        self.rows_buffer[table].append(prepared_row)

        if len(self.rows_buffer[table]) >= self.bulk_insert_rows_count:
            self.insert_buffer(table, self.rows_buffer[table])
            self.rows_buffer[table] = []

    def flush_buffer(self):
        """Write every non-empty per-table buffer and clear it."""
        for table, buffered in self.rows_buffer.items():
            if buffered:
                self.insert_buffer(table, buffered)
                self.rows_buffer[table] = []

    def insert_buffer(self, table, rows_buffer):
        """Bulk-insert *rows_buffer* into *table*.

        A failed insert is only logged; the rows are not retried.
        """
        begin_ts = time.time()
        self.logger.info("inserting {} rows to table {}".format(len(rows_buffer), table))
        try:
            self.ch_client.execute("INSERT INTO {} VALUES".format(table), rows_buffer)
            end_ts = time.time()
            self.logger.info("inserted in {}s".format(int(end_ts-begin_ts)))
        except Exception as e:
            self.logger.error("failed: {}".format(e))

    def prepare_row(self, row, custom_data):
        """Convert a prepared-log *row* into ``(table_name, values)``.

        The table is picked from ``row["ServerTimestamp"]``, which is divided
        by 1000 before ``fromtimestamp``, i.e. treated as milliseconds; the
        date uses local time.  Uid is ``ext_id`` when truthy, otherwise the
        negated :func:`yabs_md5` of Yandexuid + uid_salt.  Custom fields are
        looked up as literal top-level keys of *custom_data* (missing -> "").
        ``values`` follow the column order cached by :meth:`create_table`;
        None values become "".
        """
        row_ts = datetime.fromtimestamp(row["ServerTimestamp"]/1000)  # we can use utcfromtimestamp here
        table = row_ts.strftime('{}%Y%m%d'.format(self.tables_prefix))
        self.create_table(table)

        # prepare default fields
        ip_ver = "IPv4"
        try:
            if ipaddress.ip_address(unicode(row["ClientIP"])).version == 6:
                ip_ver = "IPv6"
        except ValueError:
            self.logger.warning("strange ClientIP '{}'".format(unicode(row["ClientIP"])))

        yuid = row["Yandexuid"]
        ext_id = row["ext_id"]
        uid = ext_id if ext_id else -yabs_md5(yuid + self.uid_salt)
        browser_name = None
        browser_version = None
        os_name = None
        os_version = None
        try:
            user_agent = parse(row["UserAgent"])
            browser_name = user_agent.browser.family
            browser_version = user_agent.browser.version_string
            os_name = user_agent.os.family
            os_version = user_agent.os.version_string
        except Exception as e:
            self.logger.error("can't parse useragent {} got error {}".format(row["UserAgent"], e))

        ch_row = {
            "EventTime": row["ServerTimestamp"],
            "IPVer": ip_ver,
            "ClientIP": row["ClientIP"],
            "UserAgent": row["UserAgent"],
            "Referer": row["Referer"],
            "BrowserName": '' if browser_name is None else browser_name,
            "BrowserVersion": '' if browser_version is None else browser_version,
            "OSName": '' if os_name is None else os_name,
            "OSVersion": '' if os_version is None else os_version,
            "Uid": str(uid),
        }

        # add custom fields
        for field in self.custom_fields:
            try:
                value = custom_data[field]
            except KeyError:
                value = ""
            ch_row[field.replace('.', '_')] = value

        # generate values in the table's column order (DESCRIBE rows start with the name)
        ch_row_values = []
        for field in self.tables[table]:
            value = ch_row[field[0]]
            ch_row_values.append("" if value is None else value)

        return table, ch_row_values

    def cleanup(self):
        """Best-effort housekeeping; every failure is logged, never raised.

        Deletes ``processed_logs`` records older than
        ``keep_processed_logs_records_days`` and drops all but the first
        ``keep_tables_count`` tables matching ``tables_prefix`` in reverse
        name order.
        """
        self._create_table_processed_logs()

        # remove old records from processed_logs
        try:
            ts = int((time.time()-self.keep_processed_logs_records_days*24*60*60)*1000)
            self.ch_client.execute("ALTER TABLE processed_logs DELETE WHERE Timestamp < {}".format(ts))

        except Exception as e:
            self.logger.error("ALTER TABLE at cleanup failed: {}".format(e))

        # drop old tables
        try:
            rows = self.ch_client.execute("SHOW TABLES LIKE '{}%'".format(self.tables_prefix))
        except Exception as e:
            self.logger.error("SHOW TABLES at cleanup failed: {}".format(e))
            return

        # Names end in YYYYMMDD, so reverse name order is presumably newest
        # first (assumes every match shares the same prefix).
        keep = max(self.keep_tables_count, 0)
        for row in sorted(rows, reverse=True)[keep:]:
            table = row[0]
            self.logger.info("drop table {}".format(table))
            try:
                self.ch_client.execute("DROP TABLE {} ON CLUSTER '{{cluster}}'".format(table))
            except Exception as e:
                self.logger.error("DROP TABLE {} at cleanup failed: {}".format(table, e))


class AvpExporter():
    """Export not-yet-processed YT verify logs into ClickHouse, one per run.

    Each run cleans up ClickHouse via a :class:`Verifier`, then scans the
    logs under ``yt_log_path`` and loads the first unprocessed one whose YT
    lock it can take.
    """

    def __init__(
            self,

            platform_id,
            platform_name,

            yt_cluster,
            yt_token,
            yql_token,
            yt_log_path,
            yt_workdir,

            cookie_sync_table,

            ch_host,
            ch_database,
            ch_user,
            ch_password,

            tables_prefix,
            keep_tables_count,
            keep_processed_logs_records_days,

            bulk_insert_rows_count,

            custom_fields,
            uid_salt):

        self.yt_cluster = yt_cluster
        self.yt_workdir = yt_workdir
        self.yt_token = yt_token
        self.yql_token = yql_token
        self.yt_log_path = yt_log_path

        self.yt_client = YtClient(yt_cluster, token=yt_token)
        self.yt_client.config['read_parallel']['enable'] = True
        self.yt_client.config['read_parallel']['max_thread_count'] = 32
        self.yt_client.config['read_parallel']['data_size_per_thread'] = 8 * 1024 * 1024

        self.yql_client = YqlClient(db=yt_cluster, token=yql_token)

        self.cookie_sync_table = cookie_sync_table

        # keyword arguments for the Verifier built at the start of each run
        self.verifier_params = {
            "platform_id": platform_id,
            "platform_name": platform_name,

            "ch_host": ch_host,
            "ch_database": ch_database,
            "ch_user": ch_user,
            "ch_password": ch_password,

            "tables_prefix": tables_prefix,
            "keep_tables_count": keep_tables_count,
            "keep_processed_logs_records_days": keep_processed_logs_records_days,

            "bulk_insert_rows_count": bulk_insert_rows_count,

            "custom_fields": custom_fields,
            "uid_salt": uid_salt,
        }

        self.logger = logging.getLogger("exporter")

    def do(self):
        """Process at most one not-yet-processed log from ``yt_log_path``.

        Logs are tried in sorted name order.  Logs already recorded in
        ``processed_logs`` are skipped; otherwise a non-blocking YT lock on
        the log path is attempted.  Logs locked elsewhere are skipped.  Once
        a lock is acquired that log is handled and the scan stops.
        """
        self.logger.info("starting")

        verifier = Verifier(**self.verifier_params)
        verifier.cleanup()

        # The lock settings do not depend on the log, so build one manager.
        manager = ylock.create_manager(
            backend='yt',
            prefix='//home/yabs/exporter',
            token=self.yt_token,
        )

        processed_logs_count = 0

        for log_name in sorted(self.yt_client.get(self.yt_log_path)):
            log = '/'.join([self.yt_log_path, log_name])

            if verifier.processed_log(log_name):
                self.logger.debug("skipping log {}".format(log))
                continue

            self.logger.info("trying lock on log {} for {}".
                             format(log, verifier.platform_name))
            with manager.lock(log, block=False) as acquired:
                if not acquired:
                    self.logger.info("skipping locked log {} for {}".format(log, verifier.platform_name))
                    continue

                # Re-check under the lock: another worker may have finished
                # this log between the check above and our lock acquisition.
                if verifier.processed_log(log_name):
                    self.logger.debug("skipping log {}".format(log))
                    continue

                self.logger.info("processing log {} for {}".format(log, verifier.platform_name))
                if self.process_log(log=log, verifier=verifier):
                    processed_logs_count += 1
                else:
                    self.logger.fatal("fatal error while processing log {}".format(log))

            # We process only one log per launch: once a lock has been
            # acquired and the log handled, stop scanning.
            break

        self.logger.info("processed {} log(s)".format(processed_logs_count))

    def process_log(self, log, verifier):
        """Load the YT table *log* into ClickHouse through *verifier*.

        Runs PREPARE_LOG_YQL_TEMPLATE into a temporary table under
        ``yt_workdir``, streams its rows into *verifier*, flushes the
        buffers and records the log as processed.  Rows whose ``Data`` fails
        to parse or whose insert preparation raises are counted and skipped.

        :returns: True on success, False if reading the prepared table
            raised YtResponseError.
        :raises RuntimeError: if the YQL query fails.
        """
        begin_ts = time.time()

        processed_rows = 0
        failed_rows = 0

        self.logger.info("preparing log %s in %s", log, self.yt_workdir)
        with self.yt_client.TempTable(self.yt_workdir, "yabs_avp_exporter_prepared_log_") as prepared_log:
            yql = PREPARE_LOG_YQL_TEMPLATE.format(
                cluster=self.yt_cluster,
                prepared_log=prepared_log,
                verify_log=log,
                cookie_sync_table=self.cookie_sync_table,
                platform_id=verifier.platform_id
            )

            self.logger.info("running prepare log yql:\n %s", yql)

            query = self.yql_client.query(yql, syntax_version=1)
            query.run()

            results = query.get_results()
            if not results.is_success:
                error_description = '\n'.join([str(err) for err in results.errors])
                self.logger.error(error_description)
                raise RuntimeError(error_description)

            self.logger.info("yql done, reading results..")

            try:
                for row in self.yt_client.read_table(prepared_log):
                    try:
                        custom_data = json.loads(row["Data"])
                        verifier.insert_row(row=row, custom_data=custom_data)
                        processed_rows += 1

                    except Exception as e:
                        self.logger.debug("strange row: %s, row: %s", e, row)
                        failed_rows += 1
            except YtResponseError as e:
                self.logger.error("yt_client.read_table failed: %s", e)
                return False

            verifier.flush_buffer()
            # NOTE(review): Verifier.insert_buffer only logs insert failures,
            # so the log is marked processed even if some batches were lost.
            verifier.update_processed_logs(log)

            end_ts = time.time()
            self.logger.info("processed log {} in {}s, processed_rows: {}, failed_rows: {}".format(log, int(end_ts-begin_ts), processed_rows, failed_rows))

            return True
