# -*- coding: utf-8 -*-
import hashlib
import os
import textwrap
from datetime import datetime, timedelta
from StringIO import StringIO

import sandbox.common.types.task as ctt
from sandbox import sdk2, common
from sandbox.projects.common import binary_task
from sandbox.projects.yabs.avp.YabsAvpMetricaExporter import YabsAvpMetricaExporter
from sandbox.sandboxsdk import errors

# Recipients of the Telegram notification sent when child export tasks fail.
TELEGRAM_RECIPIENTS = [
    "aybekko97",
    "bahbka",
]

# Link to a Sandbox task page; formatted with the task id.
SANDBOX_TASK_URL_TEMPLATE = "https://sandbox.yandex-team.ru/task/{}"
# Six hours: used both as the child kill timeout and as the wait timeout.
TIMEOUT_IN_SECONDS = 6 * 3600

# Defaults for the S3 storage parameters.
DEFAULT_AWS_ENDPOINT_URL = "http://s3.mds.yandex.net"
DEFAULT_AWS_BUCKET_NAME = "mediascope-bucket"

# Defaults for the YT parameters.
DEFAULT_YT_PROXY = "hahn"
DEFAULT_YT_METRICA_DIR = "//home/yabs-rt/ads-verify/mediascope-metrica"
DEFAULT_YT_TABLE_TTL = 15  # NOTE(review): unit is presumably days - confirm

# Yav secret ids.
DEFAULT_AWS_SECRET = "sec-01fcws88d1vbmwzxsj69avs985"  # alias: avp-mediascope
DEFAULT_YAV_SECRET = "sec-01dhp5zhknswkzjk3grmr0rebq"  # alias: robot-yabs-avp

# Metrica parts exported by default, one child task per part.
METRICA_PARTS = [
    "appmetrica",
    "appmetrica_zen",
    "appmetrica_ifv",
    "appmetrica_ifv_zen",
    "metrica",
    "metrica_zen",
]


class YabsAvpMetricaExporterScheduler(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """ YABS_AVP_METRICA_EXPORTER_SCHEDULER """

    # Scheduler task: for one export date it checks which per-part export files
    # are still missing from the S3 bucket or from the SFTP date directory,
    # starts one YabsAvpMetricaExporter child task per missing export, waits
    # for the children and fails (optionally notifying via Telegram) when any
    # child did not succeed.

    class Parameters(sdk2.Parameters):

        with sdk2.parameters.Group('Export parameters') as export_params:
            parts = sdk2.parameters.List("Metric parts", default=METRICA_PARTS, description="Metrica parts to export", required=True)
            take_yesterdays_date = sdk2.parameters.Bool("Take yesterday's date", default=True, required=True)
            # The explicit date is only requested when "yesterday" is switched off.
            with take_yesterdays_date.value[False]:
                export_date = sdk2.parameters.String("Export date (format: 'YYYY-mm-dd')", required=True)

        with sdk2.parameters.Group('YT & YQL parameters') as yt_yql_params:
            yt_proxy = sdk2.parameters.String("YT cluster to store metric tables", default=DEFAULT_YT_PROXY, required=True)
            yt_dir = sdk2.parameters.String("YT directory to store metric tables", default=DEFAULT_YT_METRICA_DIR, required=True)
            yt_ttl = sdk2.parameters.Integer("TTL for metric tables in YT", default=DEFAULT_YT_TABLE_TTL, required=True)

        with sdk2.parameters.Group('AWS S3 parameters') as aws_s3_params:
            aws_endpoint_url = sdk2.parameters.String("AWS endpoint url", default=DEFAULT_AWS_ENDPOINT_URL, required=True)
            aws_bucket_name = sdk2.parameters.String("AWS bucket name", default=DEFAULT_AWS_BUCKET_NAME, required=True)

        with sdk2.parameters.Group('SFTP parameters') as sftp_params:
            sftp_host = sdk2.parameters.String("SFTP host")
            sftp_user = sdk2.parameters.String("SFTP user")
            sftp_port = sdk2.parameters.String("SFTP port", default=22)
            sftp_path = sdk2.parameters.String("SFTP path", default="./")

        with sdk2.parameters.Group('Run') as run:
            env_vars = sdk2.parameters.Dict(
                'env_vars',
                description="Environment variables (e.g. key value).\nMay be used with Yav Secret: sec-xxx:key"
            )

        ext_params = binary_task.binary_release_parameters(stable=True)

    def get_current_aws_filenames(self):
        """Return the keys of all objects currently stored in the S3 bucket.

        Credentials are read from the AWS_ACCESS_KEY_ID and
        AWS_SECRET_ACCESS_KEY environment variables (see prepare_env_vars).

        Returns a list of object keys; the list is empty for an empty bucket.
        """
        import boto3
        session = boto3.session.Session(
            aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
            aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"]
        )
        s3 = session.client(
            service_name='s3',
            endpoint_url=str(self.Parameters.aws_endpoint_url)
        )

        # A single list_objects call returns at most 1000 keys, so walk all
        # pages; otherwise already uploaded files would look missing once the
        # bucket grows past that limit.
        paginator = s3.get_paginator('list_objects')
        filenames = []
        for page in paginator.paginate(Bucket=self.Parameters.aws_bucket_name):
            # "Contents" is absent from the response when nothing matched
            # (e.g. an empty bucket).
            filenames.extend(obj["Key"] for obj in page.get("Contents", []))

        return filenames

    def get_current_sftp_filenames(self, date_dir):
        """Return the file names found in `date_dir` on the SFTP server.

        Connects with the RSA private key from the SFTP_PKEY environment
        variable, changes into the configured SFTP path and creates
        `date_dir` there when it does not exist yet.

        date_dir: name of the per-date directory, relative to the SFTP path.
        Returns the directory listing as returned by paramiko.
        """
        import paramiko

        private_key_str = os.environ["SFTP_PKEY"]
        pkey = paramiko.RSAKey.from_private_key(StringIO(private_key_str))
        with paramiko.SSHClient() as client:
            # NOTE(review): AutoAddPolicy accepts unknown host keys without
            # verification; kept as is, but consider pinning the host key.
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(
                hostname=self.Parameters.sftp_host,
                # Honour the configured port (it is forwarded to the child
                # tasks as well); fall back to the SSH default when blank.
                port=int(self.Parameters.sftp_port or 22),
                username=self.Parameters.sftp_user,
                pkey=pkey,
                look_for_keys=False,
                allow_agent=False
            )
            sftp = client.open_sftp()
            try:
                sftp.chdir(self.Parameters.sftp_path)
                if date_dir not in sftp.listdir():
                    sftp.mkdir(date_dir)

                return sftp.listdir(path=date_dir)
            finally:
                sftp.close()

    def get_subtask(self, part, export_date, pre_export):
        """Build (without saving or enqueuing) a child exporter task.

        part: the single metrica part the child exports.
        export_date: export date string in 'YYYY-mm-dd' format.
        pre_export: value forwarded to the child's `pre_export` parameter.

        All storage related parameters and env_vars are copied from this task.
        """
        return YabsAvpMetricaExporter(
            self,
            description="part - {}, export date - {}".format(part, export_date),
            owner=self.owner,
            parts=[part],
            pre_export=pre_export,
            export_date=export_date,
            yt_proxy=self.Parameters.yt_proxy,
            yt_dir=self.Parameters.yt_dir,
            yt_ttl=self.Parameters.yt_ttl,
            aws_endpoint_url=self.Parameters.aws_endpoint_url,
            aws_bucket_name=self.Parameters.aws_bucket_name,
            sftp_host=self.Parameters.sftp_host,
            sftp_user=self.Parameters.sftp_user,
            sftp_port=self.Parameters.sftp_port,
            sftp_path=self.Parameters.sftp_path,
            env_vars=self.Parameters.env_vars,
            kill_timeout=TIMEOUT_IN_SECONDS
        )

    def prepare_env_vars(self):
        """Export the `env_vars` parameter into os.environ.

        Values starting with 'sec-' are treated as '<secret id>:<key>' and are
        replaced by the corresponding value read from the Yav secret.
        """
        for name, value in self.Parameters.env_vars.items():
            if value.startswith('sec-'):
                secret_id, secret_key = value.split(':', 1)
                value = sdk2.yav.Secret(secret_id).data()[secret_key]
            os.environ[name] = value

    def _enqueue_missing_exports(self, export_date, with_pre_export):
        """Start a child task for every export file that is not uploaded yet.

        export_date: export date string in 'YYYY-mm-dd' format.
        with_pre_export: when true, the additional ("add") export of every
            part is checked and scheduled too.

        An export counts as done only if its file is present both in the S3
        bucket and in the SFTP date directory. Returns the list of enqueued
        child tasks (possibly empty).
        """
        # Sets give O(1) membership tests for the per-part checks below.
        aws_filenames = set(self.get_current_aws_filenames())
        sftp_filenames = set(self.get_current_sftp_filenames(date_dir=export_date))
        compact_date = export_date.replace('-', '')

        # (table name template, SFTP file name template, pre_export flag)
        variants = [("{}_{}", "{}_main_{}.ndjson.gz", False)]
        if with_pre_export:
            variants.append(("{}2_{}", "{}_add_{}.ndjson.gz", True))

        subtasks = []
        for part in self.Parameters.parts:
            for table_template, sftp_template, pre_export in variants:
                table_name = table_template.format(part, export_date)
                # Objects in the bucket are named by the SHA-512 hex digest of the table name.
                aws_filename = hashlib.sha512(table_name.encode()).hexdigest()
                sftp_filename = sftp_template.format(part, compact_date)

                if aws_filename in aws_filenames and sftp_filename in sftp_filenames:
                    continue

                task = self.get_subtask(part=part, export_date=export_date, pre_export=pre_export)
                task.save().enqueue()
                subtasks.append(task)

        return subtasks

    def on_execute(self):
        """Schedule missing exports, wait for them and report failures.

        Raises sdk2.WaitTask after child tasks were enqueued, and
        errors.SandboxTaskFailureError when any child task did not succeed.
        """
        self.prepare_env_vars()

        if self.Parameters.take_yesterdays_date:
            date = datetime.now() - timedelta(days=1)
            export_date = date.strftime("%Y-%m-%d")
        else:
            export_date = self.Parameters.export_date
            date = datetime.strptime(export_date, "%Y-%m-%d")

        salt1 = date.strftime("%Y%m")
        salt2 = (date + timedelta(days=10)).strftime("%Y%m")
        # 10 days before the end of the month, we start uploading logs with a sample from the next month
        with_pre_export = salt1 != salt2

        with self.memoize_stage.create_children:
            subtasks = self._enqueue_missing_exports(export_date, with_pre_export)
            # Only suspend when something was actually scheduled; with every
            # export already in place there is nothing to wait for.
            if subtasks:
                raise sdk2.WaitTask(subtasks, (ctt.Status.Group.FINISH, ctt.Status.Group.BREAK), timeout=TIMEOUT_IN_SECONDS)

        failed_task_ids = [
            task.id for task in self.find()
            if task.status not in ctt.Status.Group.SUCCEED
        ]

        if failed_task_ids:
            if datetime.now().hour >= 10:  # send notification in daytime only
                self.server.notification(
                    body=textwrap.dedent("""\
                        [MEDIASCOPE] YabsAvpMetricaExporterScheduler
                        Task failed: {}
                        will retry after 1 hour...""".format(SANDBOX_TASK_URL_TEMPLATE.format(self.id))),
                    recipients=TELEGRAM_RECIPIENTS,
                    transport=common.types.notification.Transport.TELEGRAM
                )

            raise errors.SandboxTaskFailureError("Child task has failed.")
