#!/usr/bin/env python
# -*- coding: utf-8 -*-
from datetime import datetime, timedelta
from sandbox import sdk2
from sandbox.common.types.misc import DnsType
from sandbox.sandboxsdk import environments
import logging
import socket
import struct


class YabsAwapsAntifraudExportSubnet(sdk2.Task):
    class Parameters(sdk2.Task.Parameters):
        # Name of the Vault item that holds the token used for YQL queries
        # (resolved through Vault.data in on_execute).
        yt_token_vault_name = sdk2.parameters.String(
            "YQL robot token vault name",
            default="awaps_robot_yt_token",
            required=True,
        )
        # NOTE: despite the attribute name this is NOT the password itself.
        # on_execute passes the value to Vault.data, so it must be the name of
        # the Vault item that stores the password. The label says so explicitly
        # to keep plaintext secrets out of the task form.
        mssql_password = sdk2.parameters.String(
            "MSSQL awaps user password vault name",
            default="yt_loader_password",
            required=True,
        )
        # Login used for the MSSQL connection.
        mssql_user = sdk2.parameters.String(
            "MSSQL awaps user",
            default="yt_loader",
            required=True,
        )
        # MSSQL endpoint in "host:port" form; inserted as-is into the database URL.
        mssql_server = sdk2.parameters.String(
            "MSSQL server:port",
            default="sqllogc:1433",
            required=True,
        )
        # Directory whose tables are listed to pick the one with the greatest path.
        antifraud_subnet_dir_path = sdk2.parameters.String(
            "Antifraud AwapsStat by Subnet Dir Path",
            default="//home/antifraud/export/awaps/subnet",
            required=True,
        )
        # Cluster name handed to the YQL client as `db`.
        yt_cluster = sdk2.parameters.String(
            "Yt cluster",
            default='hahn',
            required=True,
        )

    class Requirements(sdk2.Task.Requirements):
        # Pip packages requested for the task. The methods of this task import
        # yql, sqlalchemy and pymssql lazily inside their bodies rather than at
        # module level.
        environments = (
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yql'),
            # The only package pinned to an exact version.
            environments.PipEnvironment('pymssql', version='2.1.4'),
            environments.PipEnvironment('sqlalchemy'),
        )
        dns = DnsType.DNS64  # for external interactions

    def get_process_time(self, frmt='%Y-%m-%d', days=1):
        start_time = datetime.now() - timedelta(days=days)
        return start_time.strftime(frmt)

    def run_yql_query(self, query, token, db):
        """Run a YQL query and wait until it has finished.

        :param query: text of the YQL query (submitted with syntax_version=1)
        :param token: token passed to YqlClient
        :param db: cluster name passed to YqlClient as `db`
        :return: the finished request object returned by `YqlClient.query`
        :raises Exception: when the request does not report success; the
            message contains the errors reported by the request
        """
        from yql.api.v1.client import YqlClient
        logging.info("start yql query")
        client = YqlClient(db=db, token=token)
        # Keep the request in its own name instead of rebinding the `query`
        # argument, so the query text stays available while debugging.
        request = client.query(query, syntax_version=1)
        request.run()
        request.wait_progress()

        if not request.is_success:
            error_text = '\n'.join(str(err) for err in request.errors)
            # Failures belong at error level and in the exception itself, so
            # the task failure reason is visible without digging through logs.
            logging.error(error_text)
            raise Exception("YQL query failed:\n" + error_text)
        return request

    def connect_to_mssql(self, mssql_user, mssql_password, mssql_server):
        """Create a SQLAlchemy session bound to the `import_data` MSSQL database.

        :param mssql_user: login name
        :param mssql_password: password for `mssql_user`
        :param mssql_server: "host:port" of the server, inserted as-is into the URL
        :return: a new session produced by `sessionmaker(bind=engine, autocommit=False)`
        """
        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker
        from sqlalchemy.pool import NullPool
        import pymssql  # noqa
        try:
            from urllib.parse import quote  # Python 3
        except ImportError:
            from urllib import quote  # Python 2
        logging.info("Try to connect to mssql server")

        # Credentials are percent-encoded so that characters which are special
        # in a URL ('@', ':', '/', '%', ...) cannot corrupt the connection
        # string. `quote(..., safe='')` is used instead of `quote_plus` so a
        # space becomes %20, not '+'.
        url = 'mssql+pymssql://{user}:{password}@{host_and_port}/import_data'.format(
            user=quote(mssql_user, safe=''),
            password=quote(mssql_password, safe=''),
            host_and_port=mssql_server,
        )
        engine = create_engine(url, poolclass=NullPool)
        Session = sessionmaker(bind=engine, autocommit=False)

        # NOTE(review): engine and session creation are presumably lazy, so a
        # real connection may only be opened by the first statement - confirm.
        logging.info("Succesfully connected")
        return Session()

    def _ipv6_to_int(self, addr):
        hi, lo = struct.unpack('!QQ', socket.inet_pton(socket.AF_INET6, addr))
        return (hi << 64) | lo

    def _int_to_ipv6(self, addr):
        return socket.inet_ntop(socket.AF_INET6, struct.pack('!QQ', addr >> 64, addr & ((1 << 64) - 1)))

    def _get_subnet_end(self, addr):
        try:
            i = self._ipv6_to_int(addr)
            mask_ipv6 = (1 << 64) - 1
            mask_ipv4 = (1 << 8) - 1
            if i & mask_ipv6:
                i |= mask_ipv4
            elif i > 0:
                i |= mask_ipv6
            return self._int_to_ipv6(i)
        except:
            return addr

    def load_stat_to_mssql(self):
        """Replace the contents of [subnet_antifraud_stat] with `self.antifraud_stat`.

        The rows are first inserted into the temporary table
        [#subnet_antifraud_stat] and then applied to the target table with one
        MERGE that updates matching subnets, inserts new ones and deletes
        subnets absent from the source. Everything runs in one session, which
        is committed on success, rolled back on any failure and always closed.

        Reads `self.mssql_user`, `self.mssql_password`, `self.mssql_server` and
        `self.antifraud_stat` (a list of dicts with the keys 'subnet',
        'shows_all', 'clicks_all', 'shows_fraud_share', 'clicks_fraud_share').

        NOTE(review): with an empty `self.antifraud_stat` the MERGE deletes
        every row of the target table - confirm that this is intended.

        :raises: re-raises whatever the database layer raises, after rollback
        """
        import math
        from sqlalchemy import text

        columns = ('subnet', 'subnet_end', 'shows_all', 'clicks_all', 'shows_fraud_share', 'clicks_fraud_share')
        insert_prefix = "INSERT INTO [#subnet_antifraud_stat] ({}) VALUES ".format(", ".join(columns))
        # Same ceiling as the previous batching (its first batch held 1000 rows).
        # NOTE(review): SQL Server reportedly caps INSERT ... VALUES at 1000 rows - confirm.
        batch_size = 1000

        def finite_or_none(value):
            # NaN/inf are not valid SQL float literals, so they are stored as NULL.
            if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
                return None
            return value

        session = self.connect_to_mssql(self.mssql_user, self.mssql_password, self.mssql_server)
        try:
            session.execute(text('''CREATE TABLE [#subnet_antifraud_stat](
                                  [subnet] [varchar](255) NOT NULL PRIMARY KEY CLUSTERED,
                                  [subnet_end] [varchar](255) NOT NULL,
                                  [shows_all] [bigint] NOT NULL DEFAULT (0),
                                  [clicks_all] [bigint] NOT NULL DEFAULT (0),
                                  [shows_fraud_share] [float] NULL,
                                  [clicks_fraud_share] [float] NULL
                                )
            '''))
            logging.info("created table [#subnet_antifraud_stat]")

            rows = self.antifraud_stat
            for start in range(0, len(rows), batch_size):
                placeholders = []
                params = {}
                for n, row in enumerate(rows[start:start + batch_size]):
                    # Every value is a bound parameter: data coming from the
                    # source table can no longer break or inject into the SQL.
                    placeholders.append(
                        "(" + ", ".join(":{}_{}".format(col, n) for col in columns) + ")"
                    )
                    params['subnet_{}'.format(n)] = row['subnet']
                    params['subnet_end_{}'.format(n)] = self._get_subnet_end(row['subnet'])
                    params['shows_all_{}'.format(n)] = row['shows_all']
                    params['clicks_all_{}'.format(n)] = row['clicks_all']
                    params['shows_fraud_share_{}'.format(n)] = finite_or_none(row['shows_fraud_share'])
                    params['clicks_fraud_share_{}'.format(n)] = finite_or_none(row['clicks_fraud_share'])
                session.execute(text(insert_prefix + ", ".join(placeholders)), params)

            session.execute(text('''MERGE [subnet_antifraud_stat] WITH(TABLOCK) as trg
                               USING (SELECT subnet, subnet_end, shows_all, clicks_all, shows_fraud_share, clicks_fraud_share FROM [#subnet_antifraud_stat]) as src
                               ON trg.subnet=src.subnet
                               WHEN MATCHED THEN
                               UPDATE SET
                                   trg.subnet_end=src.subnet_end,
                                   trg.shows_all=src.shows_all,
                                   trg.clicks_all=src.clicks_all,
                                   trg.shows_fraud_share=src.shows_fraud_share,
                                   trg.clicks_fraud_share=src.clicks_fraud_share
                               WHEN NOT MATCHED BY TARGET
                                   THEN INSERT (subnet, subnet_end, shows_all, clicks_all, shows_fraud_share, clicks_fraud_share)
                                   VALUES (src.subnet, src.subnet_end, src.shows_all, src.clicks_all, src.shows_fraud_share, src.clicks_fraud_share)
                               WHEN NOT MATCHED BY SOURCE
                                   THEN DELETE;
            '''))
            session.commit()
            logging.info("Successfully insert {} rows".format(len(self.antifraud_stat)))
        except BaseException:
            # Roll back on any failure, interrupts included, then re-raise.
            session.rollback()
            raise
        finally:
            session.close()

    def find_last_table(self, dir_path):
        """Return the path of the newest table in a directory.

        "Newest" means the table whose path is greatest in descending `Path`
        order, as selected by the YQL `folder()` query below. The query text is
        also stored in `self.query`.

        :param dir_path: directory to list
        :return: path of the selected table (first column of the first row)
        :raises Exception: when the query returns no rows
        """
        logging.info("Try to find fresh table in dir {}".format(dir_path))

        self.query = '''
        select
            Path
        from folder(`{0}`)
        where
            Type='table'
        order by
            Path desc
        limit 1
        '''.format(dir_path)

        yql_result = self.run_yql_query(self.query, self.yt_token, self.Parameters.yt_cluster)
        for table in yql_result.get_results():
            table.fetch_full_data()
            for row in table.rows:
                logging.info('found the newest table {}'.format(row[0]))
                return row[0]
        # Nothing came back: report at error level and raise with a message so
        # the failure reason is visible in the task status.
        logging.error('did not found any table')
        raise Exception("no table found in dir {}".format(dir_path))

    def get_antifraud_stat(self, input_table):
        """Read per-subnet statistics from `input_table` into `self.antifraud_stat`.

        Every column except `subnet` is selected as a string and converted
        here: the shares to float, the counters to int. A row is skipped when
        any of its five values is None or when the subnet or either counter
        equals the string "nan". The query text is also stored in `self.query`.

        :param input_table: path of the table to read
        """
        logging.info("Try to read antifraud stat from table {}".format(input_table))

        self.query = '''
        select
            subnet,
            cast(shows_fraud_share as String),
            cast(clicks_fraud_share as String),
            cast(shows_all as String),
            cast(clicks_all as String)
        from `{0}`
        '''.format(input_table)

        yql_result = self.run_yql_query(self.query, self.yt_token, self.Parameters.yt_cluster)

        # Bound to the attribute up front so that it exists even when the loop fails midway.
        collected = self.antifraud_stat = []
        for table in yql_result.get_results():
            table.fetch_full_data()
            for row in table.rows:
                subnet, shows_share, clicks_share, shows, clicks = row[0], row[1], row[2], row[3], row[4]
                has_null = any(
                    value is None
                    for value in (subnet, shows_share, clicks_share, shows, clicks)
                )
                if has_null or "nan" in (subnet, shows, clicks):
                    continue
                collected.append({
                    'subnet': subnet,
                    'shows_fraud_share': float(shows_share),
                    'clicks_fraud_share': float(clicks_share),
                    'shows_all': int(shows),
                    'clicks_all': int(clicks),
                })
        logging.info("finish yql success, collect {} rows".format(len(self.antifraud_stat)))

    def on_execute(self):
        """Task entry point: resolve secrets, read the newest table, load it into MSSQL."""
        params = self.Parameters

        # Secrets are fetched from the Vault under the task author; the plain
        # settings are copied from the parameters.
        self.yt_token = sdk2.task.Vault.data(self.author, params.yt_token_vault_name)
        self.mssql_user = params.mssql_user
        self.mssql_password = sdk2.task.Vault.data(self.author, params.mssql_password)
        self.mssql_server = params.mssql_server

        # Pipeline: pick the table -> read its rows -> push them to MSSQL.
        newest_table = self.find_last_table(params.antifraud_subnet_dir_path)
        self.antifraud_subnet_table = newest_table
        self.get_antifraud_stat(newest_table)
        self.load_stat_to_mssql()
