#!/usr/bin/env python
# -*- coding: utf-8 -*-
from sandbox import sdk2
from sandbox.sandboxsdk import environments
import logging
import datetime
from sandbox.common.types.misc import DnsType


class YabsAwapsAntifraudExportPage(sdk2.Task):
    """Export per-page antifraud statistics from YT into the MSSQL table [page_antifraud_stat].

    Steps (driven by on_execute):
      1. find the newest table in Parameters.antifraud_pageid_dir_path (find_last_table);
      2. read it and enrich every pageid with 12h subnet-fraud show counters and the
         median cookie age computed from awaps logs (get_antifraud_stat);
      3. MERGE the result into MSSQL (load_stat_to_mssql).
    """

    class Parameters(sdk2.Task.Parameters):
        yt_token_vault_name = sdk2.parameters.String("YQL robot token vault name", default="awaps_robot_yt_token", required=True)
        # NOTE: despite its label this is the name of a Vault entry holding the password;
        # on_execute resolves it with sdk2.task.Vault.data.
        mssql_password = sdk2.parameters.String("MSSQL awaps user password", default="yt_loader_password", required=True)
        mssql_user = sdk2.parameters.String("MSSQL awaps user", default="yt_loader", required=True)
        mssql_server = sdk2.parameters.String("MSSQL server:port", default="sqllogc:1433", required=True)
        antifraud_pageid_dir_path = sdk2.parameters.String("Antifraud AwapsStat by PageID Dir Path", default="//home/antifraud/export/awaps/pageid", required=True)
        yt_cluster = sdk2.parameters.String("Yt cluster", default='hahn', required=True)

    class Requirements(sdk2.Task.Requirements):
        environments = (
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yql'),
            environments.PipEnvironment('pymssql', version='2.1.4'),
            environments.PipEnvironment('sqlalchemy'),
        )
        dns = DnsType.DNS64  # for external interactions

    def run_yql_query(self, query, token, db):
        """Run a YQL query (syntax v1) on cluster ``db`` and block until it finishes.

        :param query: YQL query text.
        :param token: YQL token (read from Vault in on_execute).
        :param db: YT cluster name passed to YqlClient.
        :return: the finished query object; call ``get_results()`` on it for result tables.
        :raises Exception: if the query did not succeed; the message contains the YQL errors.
        """
        from yql.api.v1.client import YqlClient
        logging.info("start yql query")
        client = YqlClient(db=db, token=token)
        request = client.query(query, syntax_version=1)
        request.run()
        request.wait_progress()

        if not request.is_success:
            errors = '\n'.join(str(err) for err in request.errors)
            logging.error(errors)
            raise Exception('YQL query failed:\n{}'.format(errors))
        return request

    def connect_to_mssql(self, mssql_user, mssql_password, mssql_server):
        """Create an SQLAlchemy session for database ``import_data`` on ``mssql_server``.

        :param mssql_user: login name.
        :param mssql_password: plain-text password. It is URL-escaped before being put into
            the connection URL, so characters such as ``@``, ``/``, ``:`` or ``%`` are safe.
        :param mssql_server: ``host:port`` string.
        :return: a new Session; the caller is responsible for commit/rollback and close.
        """
        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker
        from sqlalchemy.pool import NullPool
        import pymssql  # noqa
        try:
            from urllib.parse import quote  # Python 3
        except ImportError:
            from urllib import quote  # Python 2
        logging.info("Try to connect to mssql server")

        url = 'mssql+pymssql://{user}:{password}@{host_and_port}/import_data'.format(
            user=mssql_user,
            # safe='' so '/' is escaped too; SQLAlchemy URL-decodes the password when parsing.
            password=quote(mssql_password, safe=''),
            host_and_port=mssql_server,
        )
        # NullPool: no connection pooling, the connection is released with the session.
        engine = create_engine(url, poolclass=NullPool)
        # Non-autocommit is the default on SQLAlchemy 1.x; the `autocommit` argument was
        # removed in 2.0 and `sqlalchemy` is not pinned in Requirements.
        Session = sessionmaker(bind=engine)

        logging.info("Succesfully connected")
        return Session()

    @staticmethod
    def _sql_float(value):
        """Render a float as a T-SQL literal; None, NaN and +/-inf become NULL."""
        import math
        if value is None or math.isnan(value) or math.isinf(value):
            return 'NULL'
        # repr keeps full precision (Python 2 str() truncates floats to 12 digits).
        return repr(value)

    @classmethod
    def _format_values_row(cls, row):
        """Render one self.antifraud_stat dict as a ``(...)`` tuple for INSERT ... VALUES.

        The column order must match the column list of the staging-table INSERT.
        """
        return "(" + ", ".join([
            str(row['pageid']),
            str(row['shows_all']),
            str(row['clicks_all']),
            cls._sql_float(row['shows_fraud_share']),
            cls._sql_float(row['clicks_fraud_share']),
            str(row['shows_12h']),
            str(row['fraud_by_subnet_shows_12h']),
            str(row['median_cookie_age']),
        ]) + ")"

    def load_stat_to_mssql(self):
        """Synchronize MSSQL table [page_antifraud_stat] with self.antifraud_stat.

        Rows are bulk-inserted into the temporary table [#page_antifraud_stat] in batches.
        That table is then MERGEd into [page_antifraud_stat]: matching pageids are updated,
        new ones are inserted, and pageids missing from self.antifraud_stat are deleted.
        Everything runs in one transaction, which is rolled back on any error.
        """
        from sqlalchemy import text

        # SQL Server accepts at most 1000 rows in a single INSERT ... VALUES list.
        batch_size = 1000
        insert_prefix = (
            "INSERT INTO [#page_antifraud_stat] (pageid, shows_all, clicks_all, shows_fraud_share, "
            "clicks_fraud_share, shows_12h, fraud_by_subnet_shows_12h, median_cookie_age) VALUES "
        )
        # Values are numbers we converted ourselves (int()/float() or YQL numeric columns),
        # so building the VALUES text directly does not splice arbitrary strings into SQL.
        values = [self._format_values_row(row) for row in self.antifraud_stat]

        session = self.connect_to_mssql(self.mssql_user, self.mssql_password, self.mssql_server)
        try:
            session.execute(text('''CREATE TABLE [#page_antifraud_stat](
                                [pageid] [int] NOT NULL PRIMARY KEY CLUSTERED,
                                [shows_all] [bigint] NOT NULL DEFAULT (0),
                                [clicks_all] [bigint] NOT NULL DEFAULT (0),
                                [shows_fraud_share] [float] NULL,
                                [clicks_fraud_share] [float] NULL,
                                [shows_12h] [bigint] NOT NULL DEFAULT (0),
                                [fraud_by_subnet_shows_12h] [bigint] NOT NULL DEFAULT (0),
                                [median_cookie_age] [bigint] NOT NULL DEFAULT (0)
            )'''))
            for start in range(0, len(values), batch_size):
                session.execute(text(insert_prefix + ", ".join(values[start:start + batch_size])))
            # NOTE(review): because of WHEN NOT MATCHED BY SOURCE ... DELETE, an empty
            # self.antifraud_stat empties [page_antifraud_stat]; confirm that is intended.
            session.execute(text(
                '''MERGE [page_antifraud_stat] WITH(TABLOCK) as trg
                USING (SELECT pageid, shows_all, clicks_all, shows_fraud_share, clicks_fraud_share, shows_12h, fraud_by_subnet_shows_12h, median_cookie_age FROM [#page_antifraud_stat]) as src
                ON trg.pageid=src.pageid
                WHEN MATCHED THEN
                UPDATE SET
                    trg.shows_all=src.shows_all,
                    trg.clicks_all=src.clicks_all,
                    trg.shows_fraud_share=src.shows_fraud_share,
                    trg.clicks_fraud_share=src.clicks_fraud_share,
                    trg.shows_12h=src.shows_12h,
                    trg.fraud_by_subnet_shows_12h=src.fraud_by_subnet_shows_12h,
                    trg.median_cookie_age=src.median_cookie_age
                WHEN NOT MATCHED BY TARGET
                    THEN INSERT (pageid, shows_all, clicks_all, shows_fraud_share, clicks_fraud_share, shows_12h, fraud_by_subnet_shows_12h, median_cookie_age)
                    VALUES (src.pageid, src.shows_all, src.clicks_all, src.shows_fraud_share, src.clicks_fraud_share, src.shows_12h, src.fraud_by_subnet_shows_12h, src.median_cookie_age)
                WHEN NOT MATCHED BY SOURCE
                    THEN DELETE;'''))
            session.commit()
            logging.info("Successfully insert {} rows".format(len(self.antifraud_stat)))
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    @staticmethod
    def _parse_table_datetime(path):
        """Return the datetime encoded in a table name such as ``.../2020-01-31T21:00:00``.

        Returns None when the last path component is not in ``%Y-%m-%dT%H:%M:%S`` form.
        """
        try:
            return datetime.datetime.strptime(path.split('/')[-1], '%Y-%m-%dT%H:%M:%S')
        except ValueError:
            return None

    def find_last_table(self, dir_path):
        """Locate the most recent tables in ``dir_path``.

        Sets:
          * self.antifraud_pageid_tables_12h: paths of the (up to) 4 lexicographically
            greatest tables in the folder; get_stats_12h uses their names as time slices;
          * self.antifraud_pageid_table: the one among them with the latest timestamp name.

        :raises Exception: if no table with a timestamp name was found.
        """
        logging.info("Try to find fresh table in dir {}".format(dir_path))

        self.query = '''
        select
            Path
        from folder(`{0}`)
        where
            Type='table'
        order by Path desc
        limit 4
        '''.format(dir_path)

        yql_result = self.run_yql_query(self.query, self.yt_token, self.Parameters.yt_cluster)
        max_datetime = datetime.datetime(1900, 1, 1, 0, 0, 0)
        self.antifraud_pageid_table = ''
        self.antifraud_pageid_tables_12h = []
        for table in yql_result.get_results():
            table.fetch_full_data()
            for row in table.rows:
                path = row[0]
                self.antifraud_pageid_tables_12h.append(path)
                table_datetime = self._parse_table_datetime(path)
                # Tables whose name is not a timestamp are never chosen as the newest one.
                if table_datetime is not None and table_datetime > max_datetime:
                    max_datetime = table_datetime
                    self.antifraud_pageid_table = path

        if self.antifraud_pageid_table == '':
            logging.info('dont find any table')
            raise Exception('No timestamp-named table found in {}'.format(dir_path))
        logging.info('find the newest table {}'.format(self.antifraud_pageid_table))

    def get_stats_12h(self):
        """Count shows and subnet-fraud shows per pageid for the slices found by find_last_table.

        For every table name D in self.antifraud_pageid_tables_12h, awaps 5-minute logs in
        [D, D + 2h55m] with actionid '0' are joined on the client's subnet to
        //home/antifraud/export/awaps/subnet/D. The query counts all shows and shows whose
        subnet has shows_fraud_share > 0.09, grouped by rtb_site_id.

        Fills self.stats_12h[pageid] with 'fraud_by_subnet_shows_12h' and 'shows_12h'
        (self.stats_12h must already exist) and returns it.
        """
        dates = [x.split('/')[-1] for x in self.antifraud_pageid_tables_12h]
        # Render the table names as YQL string literals for AsList(...).
        dates_list = ", ".join("'{}'".format(date) for date in dates)
        # NOTE(review): DateTime::Update below sets Hour to hour + 2 and Minute to
        # minute + 55 directly. This assumes slice names have hour <= 21 and minute <= 4;
        # confirm against the naming of the subnet tables.
        self.query = '''
$script = @@
import socket
import struct

def ipv6_to_int(addr):
    hi, lo = struct.unpack('!QQ', socket.inet_pton(socket.AF_INET6, addr))
    return (hi << 64) | lo

def int_to_ipv6(addr):
    return socket.inet_ntop(socket.AF_INET6, struct.pack('!QQ', addr >> 64, addr & ((1 << 64) - 1)))

def is_ipv4(addr):
    mask1 = ((1 << 16) - 1) << 32
    mask2 = (1 << 32) - 1
    return addr - (addr & mask2) == mask1

def get_subnet(addr):
    try:
        i = ipv6_to_int(addr)
        mask_ipv6 = (1 << 64) - 1
        mask_ipv4 = (1 << 8) - 1
        if is_ipv4(i):
            i -= i & mask_ipv4
        elif i > 0:
            i -= i & mask_ipv6
        return int_to_ipv6(i)
    except:
        return addr
@@;
$get_subnet = Python2::get_subnet(Callable<(String?)->String?>, $script);


DEFINE ACTION $calc($date) AS
    $table_path = '//home/antifraud/export/awaps/subnet/' || CAST($date AS String);
    $format = DateTime::Format("%Y-%m-%dT%H:%M:%S");
    $parse = DateTime::Parse("%Y-%m-%dT%H:%M:%S");
    $date_dt = $parse($date);
    $date_end = $format(DateTime::Update($date_dt,
                                         cast(DateTime::GetHour($date_dt) + 2 as Uint8) as Hour,
                                         cast(DateTime::GetMinute($date_dt) + 55 as Uint8) as Minute));
    INSERT INTO @tmp
    SELECT
        $date as time,
        pageid,
        count_if(cast(subnet.shows_fraud_share as double) > 0.09) as bad_shows,
        count(*) as shows
    FROM range(`home/logfeller/logs/awaps-log/stream/5min`, $date, $date_end) as log
    join $table_path as subnet
    on $get_subnet(log.ipv6) == subnet.subnet
    where
        log.actionid = '0'
    group by
        log.rtb_site_id as pageid;
END DEFINE;

EVALUATE FOR $date IN AsList({0}) DO $calc($date);

commit;

select
    pageid as pageid,
    sum(bad_shows) as bad_shows,
    sum(shows) as shows
from @tmp
group by
    pageid;
        '''.format(dates_list)

        logging.info("run query: {}".format(self.query))

        yql_result = self.run_yql_query(self.query, self.yt_token, self.Parameters.yt_cluster)

        for table in yql_result.get_results():
            table.fetch_full_data()
            logging.info('fetched %d rows', len(table.rows))
            for row in table.rows:
                if row[0] is None or row[1] is None or row[2] is None:
                    continue
                # Result columns: pageid, bad_shows, shows.
                self.stats_12h[row[0]] = {
                    'fraud_by_subnet_shows_12h': row[1],
                    'shows_12h': row[2],
                }

        return self.stats_12h

    def _get_log_datetime(self, dt):
        """Round ``dt`` down to a 5-minute boundary and format it as ``%Y-%m-%dT%H:%M:%S``.

        The result is used as a bound for range() over the 5-minute log tables, so every
        field must be zero-padded (e.g. ``T10:05:00``, not ``T10:5:00``).
        """
        rounded = dt.replace(minute=dt.minute - dt.minute % 5, second=0, microsecond=0)
        return rounded.strftime('%Y-%m-%dT%H:%M:%S')

    def get_median_cookie_age(self):
        """Compute per-pageid median cookie age over the last 12 hours of awaps logs.

        "Age" is unixtime minus the last 10 digits of yandexuid (presumably the cookie
        creation time; TODO confirm). Only shows (actionid '0') with a non-empty
        rtb_site_id and yandexuid and a positive age are considered.

        Stores the value as self.stats_12h[pageid]['median_cookie_age'] (self.stats_12h must
        already exist) and returns self.stats_12h.
        """
        now = datetime.datetime.now()
        begin = self._get_log_datetime(now - datetime.timedelta(hours=12))
        end = self._get_log_datetime(now)
        self.query = '''
SELECT
    rtb_site_id,
    cast(MEDIAN(value) as uint64) as mid
from (SELECT
        rtb_site_id,
        cast(unixtime as uint64) - (cast(yandexuid as uint64) % 10000000000ul) as value
    FROM range(`home/logfeller/logs/awaps-log/stream/5min`, `{0}`, `{1}`)
    where
        actionid = '0' and
        rtb_site_id != '' and
        yandexuid != '' and
        cast(unixtime as uint64) > (cast(yandexuid as uint64) % 10000000000ul)
    )
group by
    rtb_site_id
        '''.format(begin, end)
        logging.info("run query: {}".format(self.query))

        yql_result = self.run_yql_query(self.query, self.yt_token, self.Parameters.yt_cluster)

        for table in yql_result.get_results():
            table.fetch_full_data()
            logging.info('fetched %d rows', len(table.rows))
            for row in table.rows:
                if row[0] is None or row[1] is None:
                    continue
                self.stats_12h.setdefault(row[0], {})['median_cookie_age'] = row[1]

        return self.stats_12h

    def get_antifraud_stat(self):
        """Build self.antifraud_stat: one dict per pageid, ready for load_stat_to_mssql.

        Reads pageid, fraud shares and show/click totals from self.antifraud_pageid_table
        (set by find_last_table). It then merges in the 12h subnet counters and the median
        cookie age (get_stats_12h / get_median_cookie_age), which default to 0 for pageids
        missing there.

        Rows are skipped when any column is NULL or when pageid/shows_all/clicks_all is
        "nan". A "nan" fraud share is kept as a float NaN.
        """
        logging.info("Try to read antifraud stat from table {}".format(self.antifraud_pageid_table))

        self.query = '''
        select
            cast(pageid as String),
            cast(shows_fraud_share as String),
            cast(clicks_fraud_share as String),
            cast(shows_all as String),
            cast(clicks_all as String)
        from `{0}`
        '''.format(self.antifraud_pageid_table)

        yql_result = self.run_yql_query(self.query, self.yt_token, self.Parameters.yt_cluster)

        self.stats_12h = {}
        self.get_stats_12h()
        self.get_median_cookie_age()

        self.antifraud_stat = []
        for table in yql_result.get_results():
            table.fetch_full_data()
            for row in table.rows:
                if any(row[i] is None for i in range(5)) or 'nan' in (row[0], row[3], row[4]):
                    continue
                page_id = row[0]
                page_stats_12h = self.stats_12h.get(page_id, {})
                self.antifraud_stat.append({
                    'pageid': int(page_id),
                    'shows_fraud_share': float(row[1]),
                    'clicks_fraud_share': float(row[2]),
                    'shows_all': int(row[3]),
                    'clicks_all': int(row[4]),
                    'shows_12h': page_stats_12h.get('shows_12h', 0),
                    'fraud_by_subnet_shows_12h': page_stats_12h.get('fraud_by_subnet_shows_12h', 0),
                    'median_cookie_age': page_stats_12h.get('median_cookie_age', 0),
                })
        logging.info("finish yql success, collect {} rows".format(len(self.antifraud_stat)))

    def on_execute(self):
        """Sandbox entry point: resolve secrets from Vault, then find, read and load the stats."""
        self.yt_token = sdk2.task.Vault.data(self.author, self.Parameters.yt_token_vault_name)
        self.mssql_user = self.Parameters.mssql_user
        # Parameters.mssql_password is a Vault entry name, not the password itself.
        self.mssql_password = sdk2.task.Vault.data(self.author, self.Parameters.mssql_password)
        self.mssql_server = self.Parameters.mssql_server

        self.find_last_table(self.Parameters.antifraud_pageid_dir_path)
        self.get_antifraud_stat()
        self.load_stat_to_mssql()
