import logging
import hashlib
import yt.wrapper as yt
import boto3

from yql.api.v1.client import YqlClient
from datetime import datetime, timedelta


# Extra SELECT-column fragment producing CampaignID and BannerID.
# Not referenced in the visible part of this file; the verify-log query in
# VerifyLogsHandler._create_verify_log already selects both columns itself.
BID_CID_QUERY = """, $get_property(Data, "/CID") AS CampaignID
    , $get_property(Data, "/BID") AS BannerID"""
# Extra SELECT-column fragment producing msid. Platform 1 stores it as
# `additional_rows`, which is substituted into the `{select_end}` placeholder
# of the verify-log query.
MSID_QUERY = """, $get_property(Data, "/msid") AS msid"""


class VerifyLogsHandler(object):
    """Build per-day verify-log tables in YT and mirror them into an S3 bucket.

    ``create_new_logs`` creates the missing daily tables with a YQL query.
    ``copy_logs_to_mds_and_clear_mds`` uploads the recent tables to the bucket
    (the object key is the SHA-512 hex digest of the table name) and deletes
    stale objects from it.
    """

    # 10000 is the maximum parts count in a multipart upload.
    MAX_MULTIPART_PARTS = 10000
    # Lower bound on the number of table rows packed into one uploaded part.
    MIN_ROWS_PER_PART = 2000000

    def __init__(self, platform_id, mds_secret_key, mds_access_key_id, yt_token, yql_token):
        """Resolve platform settings and create the S3, YT and YQL clients.

        Args:
            platform_id: platform identifier; converted with ``int()``.
            mds_secret_key: secret access key for the S3 session.
            mds_access_key_id: access key id for the S3 session.
            yt_token: token passed to the YT client.
            yql_token: token passed to the YQL client.

        Raises:
            ValueError: if the platform id is not a known one.
        """
        self.platform_id = int(platform_id)
        # Must run first: the clients below read mds_host / yt_cluster set here.
        self._set_platform_properies()

        session = boto3.session.Session(
            aws_access_key_id=mds_access_key_id,
            aws_secret_access_key=mds_secret_key,
        )
        self.s3 = session.client(
            service_name='s3',
            endpoint_url=self.mds_host,
        )

        self.ytc = yt.YtClient(proxy=self.yt_cluster, token=yt_token)

        self.yql_client = YqlClient(
            db=self.yt_cluster,
            token=yql_token,
        )

        # Single reference time so that every date computed by this instance agrees.
        self.now = datetime.now()

    def _set_platform_properies(self):
        """Set the per-platform attributes (paths, bucket, retention settings).

        Raises:
            ValueError: if ``self.platform_id`` is not a known platform.
        """
        if self.platform_id == 1:  # mediascope
            self.yt_logs_path = "//home/yabs-rt/ads-verify/mediascope"
            self.yt_cluster = "hahn"
            self.table_prefix = "AdsVerify_"
            self.bucket_name = "mediascope-bucket"
            self.additional_rows = MSID_QUERY
            self.mds_host = "https://s3.mds.yandex.net"
            self.count_to_save = 20
            self.days_never_remove = 5
        else:
            raise ValueError("Unknown platform id: '%d'" % self.platform_id)

    def create_new_logs(self):
        """Create one log table per day from the first unprocessed day through yesterday."""
        unprocessed_date = self._get_first_unprocessed_day()

        logging.info("Last unprocessed date: '%s'.", unprocessed_date)

        yesterday = self.now.date() - timedelta(days=1)

        while unprocessed_date <= yesterday:
            logging.info("Create log for date: %s", unprocessed_date)
            self._create_verify_log(unprocessed_date)
            unprocessed_date += timedelta(days=1)

    def _get_first_unprocessed_day(self):
        """Return the date following the newest existing log table.

        When no table with ``self.table_prefix`` exists, yesterday is returned.

        Raises:
            ValueError: if the newest prefixed name does not end in ``%Y-%m-%d``.
        """
        prefix = self.table_prefix
        log_dates = [
            name[len(prefix):]
            for name in self.ytc.list(self.yt_logs_path)
            if name.startswith(prefix)
        ]

        if not log_dates:
            return (self.now - timedelta(days=1)).date()

        # ISO dates order lexicographically, so max() picks the newest table.
        newest = datetime.strptime(max(log_dates), "%Y-%m-%d")
        return (newest + timedelta(days=1)).date()

    def _create_verify_log(self, curr_date):
        """Run the YQL query that fills the log table for ``curr_date``.

        Args:
            curr_date: date whose source table is read and whose log table is written.

        Raises:
            RuntimeError: if the query does not finish successfully; the message
                is the newline-joined list of reported errors.
        """
        log_name = self._get_log_name(curr_date)
        log_path = self.yt_logs_path + "/" + log_name
        # NOTE(review): the RND column tests `$get_property(...) != null`; confirm
        # that this behaves as an IS NOT NULL check in YQL.
        request = self.yql_client.query(
            """
            $get_property = ($data, $property_name) -> {{ return (Yson::ConvertToString(
                    Yson::YPath(
                        Yson::ParseJson($data, Yson::Options(false AS Strict)), $property_name, Yson::Options(false AS Strict)
                    )
                ))
            }};

            $get_location = ($data) -> {{ return Url::CutQueryStringAndFragment($get_property($data, "/REF")); }};


            INSERT INTO `{log_path}`
            SELECT
                ServerTimestamp AS EventTime
                , COALESCE(Cast(UserAgent::Parse(UserAgent).BrowserName AS String), "")  || " "
                    || COALESCE(Cast(UserAgent::Parse(UserAgent).BrowserVersion AS String), "") || " "
                    || COALESCE(Cast(UserAgent::Parse(UserAgent).OSFamily AS String), "") AS UserAgent
                , if (Ip::IsIPv6(Ip::FromString(ClientIP)), "IPv6", "IPv4") AS IPVer
                , if (Ip::IsIPv6(Ip::FromString(ClientIP)),
                    Ip::ToString(Ip::GetSubnet(Ip::FromString(ClientIP), 48)),
                    Ip::ToString(Ip::GetSubnet(Ip::FromString(ClientIP), 24))
                ) AS ClientIP
                , if(page.OptionsApp and page.OptionsMobile,
                   Url::GetHost($get_location(Data)),
                   $get_location(Data),
                ) AS Location
                , $get_property(Data, "/DTYPE") AS DeviceType
                , if ($get_property(Data, "/DRND") != null, $get_property(Data, "/DRND"), String::SplitToList(Host, ".")[0]) AS RND
                , $get_property(Data, "/customdata") AS Customdata
                , $get_property(Data, "/BTYPE") AS AdParams
                , $get_property(Data, "/SESSION") AS SessionID
                , $get_property(Data, "/CID") AS CampaignID
                , $get_property(Data, "/BID") AS BannerID
                {select_end}
            FROM `logs/bs-proto-verify-log/1d/{date}` AS verify
            LEFT JOIN `home/yabs/dict/Page` AS page
                ON page.PageID = CAST($get_property(verify.Data, "/page") AS Int64)
            WHERE
                $get_property(Data, "/platformid") == "{platformid}"
            """.format(log_path=log_path, date=curr_date.strftime("%Y-%m-%d"), platformid=self.platform_id, select_end=self.additional_rows),
            syntax_version=1
        )
        request.run()

        results = request.get_results()
        if not results.is_success:
            error_description = '\n'.join([str(err) for err in results.errors])
            logging.error(error_description)
            raise RuntimeError(error_description)

    def _get_log_name(self, curr_date):
        """Return the table name for ``curr_date``: the prefix followed by ``YYYY-MM-DD``."""
        return self.table_prefix + curr_date.strftime("%Y-%m-%d")

    @staticmethod
    def _sha_key(log_name):
        """Return the SHA-512 hex digest of ``log_name`` (used as the bucket key)."""
        # hashlib needs bytes; text names are encoded, byte strings pass through.
        if not isinstance(log_name, bytes):
            log_name = log_name.encode("utf-8")
        return hashlib.sha512(log_name).hexdigest()

    def copy_logs_to_mds_and_clear_mds(self):
        """Upload recent YT logs missing from the bucket, then delete stale objects."""
        yt_logs = self._get_last_logs_from_yt()
        mds_logs = self._get_logs_from_mds()
        dict_log_name_sha = {log_name: self._sha_key(log_name) for log_name in yt_logs}
        self._add_new_logs(dict_log_name_sha, set(mds_logs))
        self._remove_old_logs(dict_log_name_sha, mds_logs)

    def _get_last_logs_from_yt(self):
        """Return names of existing log tables in the recent window, oldest first.

        The window starts ``count_to_save + 1`` days before ``self.now`` and
        ends with yesterday.
        """
        existing_tables = set(self.ytc.list(self.yt_logs_path))
        today = self.now.date()
        curr_date = (self.now - timedelta(days=self.count_to_save + 1)).date()
        found_logs = []
        while curr_date < today:
            log_name = self._get_log_name(curr_date)
            if log_name in existing_tables:
                found_logs.append(log_name)
            curr_date += timedelta(days=1)
        return found_logs

    def _get_logs_from_mds(self):
        """Return ``{object key: LastModified}`` for every object in the bucket.

        All listing pages are read, and an empty bucket gives an empty dict.
        """
        logs = {}
        paginator = self.s3.get_paginator("list_objects")
        for page in paginator.paginate(Bucket=self.bucket_name):
            # A page without objects carries no "Contents" entry.
            for obj in page.get("Contents", []):
                logs[obj["Key"]] = obj["LastModified"]
        return logs

    def _add_new_logs(self, dict_log_name_sha, mds_logs):
        """Upload every log whose key is not already present in the bucket.

        Args:
            dict_log_name_sha: mapping of YT table name to its bucket key.
            mds_logs: container of keys already stored in the bucket.
        """
        stored_keys = set(mds_logs)
        for log_name, name_sha in dict_log_name_sha.items():
            if name_sha in stored_keys:
                continue
            logging.info("Copy '%s' to mds", log_name)
            self._add_log_to_mds(self.yt_logs_path + "/" + log_name, name_sha)

    def _upload_part(self, key, upload_id, part_number, rows):
        """Upload ``rows`` as one part and return its PartNumber/ETag record."""
        response = self.s3.upload_part(
            Body=b"".join(rows),
            Bucket=self.bucket_name,
            Key=key,
            UploadId=upload_id,
            PartNumber=part_number,
        )
        return {"PartNumber": part_number, "ETag": response["ETag"]}

    def _add_log_to_mds(self, yt_path, key):
        """Copy the table at ``yt_path`` into the bucket under ``key``.

        The table is read in raw JSON format and sent as a multipart upload.
        If the table yields no rows, or any error occurs, the upload is
        aborted; errors are re-raised.

        Args:
            yt_path: full YT path of the table to copy.
            key: object key to store the data under.
        """
        # NOTE(review): assumes iterating the raw stream yields one chunk of
        # bytes per row - confirm against the YT client.
        table = self.ytc.read_table(yt_path, format="json", raw=True)
        rows_count = self.ytc.row_count(yt_path)
        # Rows per part: large enough to stay within the parts-count limit.
        part_size = max(rows_count // self.MAX_MULTIPART_PARTS + 1, self.MIN_ROWS_PER_PART)
        mpu = self.s3.create_multipart_upload(Bucket=self.bucket_name, Key=key)
        upload_id = mpu["UploadId"]
        try:
            parts_configs = []
            part = []
            for row in table:
                # Append first, flush after: the row that fills a part must
                # not be lost.
                part.append(row)
                if len(part) >= part_size:
                    parts_configs.append(self._upload_part(key, upload_id, len(parts_configs) + 1, part))
                    part = []

            if part:
                parts_configs.append(self._upload_part(key, upload_id, len(parts_configs) + 1, part))

            if parts_configs:
                self.s3.complete_multipart_upload(Bucket=self.bucket_name, Key=key, UploadId=upload_id, MultipartUpload={"Parts": parts_configs})
            else:
                # Nothing was uploaded; do not leave a dangling upload behind.
                self.s3.abort_multipart_upload(Bucket=self.bucket_name, Key=key, UploadId=upload_id)
        except BaseException:
            # Clean up on any failure, including interrupts, then propagate.
            self.s3.abort_multipart_upload(Bucket=self.bucket_name, Key=key, UploadId=upload_id)
            raise

    def _remove_old_logs(self, dict_log_name_sha, mds_logs):
        """Delete bucket objects that match no current log and are old enough.

        An object is deleted when its key is not among the values of
        ``dict_log_name_sha`` and its modification date is earlier than
        ``days_never_remove`` days before ``self.now``.

        Args:
            dict_log_name_sha: mapping of YT table name to its bucket key.
            mds_logs: mapping of bucket key to its last-modified datetime.
        """
        current_keys = set(dict_log_name_sha.values())
        min_log_date = (self.now - timedelta(days=self.days_never_remove)).date()
        for sha_name, last_modified in mds_logs.items():
            if sha_name in current_keys or last_modified.date() >= min_log_date:
                continue
            logging.info("Delete old log: '%s'", sha_name)
            self.s3.delete_object(Bucket=self.bucket_name, Key=sha_name)
